You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by al...@apache.org on 2019/08/02 13:18:43 UTC

[cassandra] 02/09: Make repair coordination less expensive by moving MerkleTrees off heap

This is an automated email from the ASF dual-hosted git repository.

aleksey pushed a commit to branch 15202-4.0
in repository https://gitbox.apache.org/repos/asf/cassandra.git

commit be64c1c26a879bad97390ad3d38e2ed03806c8c8
Author: Jeff Jirsa <jj...@apple.com>
AuthorDate: Sat Mar 16 17:30:54 2019 +0000

    Make repair coordination less expensive by moving MerkleTrees off heap
    
    patch by Aleksey Yeschenko and Jeff Jirsa; reviewed by Benedict Elliott
    Smith and Marcus Eriksson for CASSANDRA-15202
    
    Co-authored-by: Aleksey Yeschenko <al...@apache.org>
    Co-authored-by: Jeff Jirsa <jj...@apple.com>
---
 src/java/org/apache/cassandra/config/Config.java   |    2 +
 .../cassandra/config/DatabaseDescriptor.java       |   11 +
 .../cassandra/dht/ByteOrderedPartitioner.java      |    6 +
 .../org/apache/cassandra/dht/IPartitioner.java     |    5 +
 .../apache/cassandra/dht/Murmur3Partitioner.java   |   33 +
 .../apache/cassandra/dht/RandomPartitioner.java    |   32 +
 src/java/org/apache/cassandra/dht/Token.java       |   40 +-
 .../org/apache/cassandra/repair/RepairJob.java     |    3 +
 .../org/apache/cassandra/repair/Validator.java     |   73 +-
 .../repair/messages/ValidationComplete.java        |    9 +
 .../cassandra/service/ActiveRepairService.java     |   10 +
 .../org/apache/cassandra/utils/ByteBufferUtil.java |   30 +-
 .../org/apache/cassandra/utils/FBUtilities.java    |   16 +
 .../apache/cassandra/utils/FastByteOperations.java |   46 +-
 .../org/apache/cassandra/utils/MerkleTree.java     | 1406 +++++++++++++-------
 .../org/apache/cassandra/utils/MerkleTrees.java    |   63 +-
 .../apache/cassandra/repair/LocalSyncTaskTest.java |    4 -
 .../org/apache/cassandra/repair/RepairJobTest.java |    4 -
 .../org/apache/cassandra/repair/ValidatorTest.java |   18 +-
 .../repair/asymmetric/DifferenceHolderTest.java    |    6 +-
 .../org/apache/cassandra/utils/MerkleTreeTest.java |  127 +-
 .../apache/cassandra/utils/MerkleTreesTest.java    |   99 +-
 22 files changed, 1344 insertions(+), 699 deletions(-)

diff --git a/src/java/org/apache/cassandra/config/Config.java b/src/java/org/apache/cassandra/config/Config.java
index 34a5ce8..b56419d 100644
--- a/src/java/org/apache/cassandra/config/Config.java
+++ b/src/java/org/apache/cassandra/config/Config.java
@@ -130,6 +130,8 @@ public class Config
     public volatile Integer repair_session_max_tree_depth = null;
     public volatile Integer repair_session_space_in_mb = null;
 
+    public volatile boolean prefer_offheap_merkle_trees = true;
+
     public int storage_port = 7000;
     public int ssl_storage_port = 7001;
     public String listen_address;
diff --git a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
index 0166c5f..49761da 100644
--- a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
+++ b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
@@ -2907,4 +2907,15 @@ public class DatabaseDescriptor
     {
         return strictRuntimeChecks;
     }
+
+    public static boolean getOffheapMerkleTreesEnabled()
+    {
+        return conf.prefer_offheap_merkle_trees;
+    }
+
+    public static void setOffheapMerkleTreesEnabled(boolean value)
+    {
+        logger.info("Setting prefer_offheap_merkle_trees to {}", value);
+        conf.prefer_offheap_merkle_trees = value;
+    }
 }
diff --git a/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java b/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java
index 08088f7..a6314dc 100644
--- a/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java
+++ b/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java
@@ -234,6 +234,12 @@ public class ByteOrderedPartitioner implements IPartitioner
             return new BytesToken(bytes);
         }
 
+        @Override
+        public int byteSize(Token token)
+        {
+            return ((BytesToken) token).token.length;
+        }
+
         public String toString(Token token)
         {
             BytesToken bytesToken = (BytesToken) token;
diff --git a/src/java/org/apache/cassandra/dht/IPartitioner.java b/src/java/org/apache/cassandra/dht/IPartitioner.java
index f433f20..ef8ced2 100644
--- a/src/java/org/apache/cassandra/dht/IPartitioner.java
+++ b/src/java/org/apache/cassandra/dht/IPartitioner.java
@@ -135,4 +135,9 @@ public interface IPartitioner
     {
         return Optional.empty();
     }
+
+    default public int getMaxTokenSize()
+    {
+        return Integer.MIN_VALUE;
+    }
 }
diff --git a/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java b/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java
index 0f922e3..52d0efb 100644
--- a/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java
+++ b/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java
@@ -17,6 +17,7 @@
  */
 package org.apache.cassandra.dht;
 
+import java.io.IOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
@@ -25,10 +26,12 @@ import java.util.concurrent.ThreadLocalRandom;
 
 import org.apache.cassandra.db.DecoratedKey;
 import org.apache.cassandra.db.PreHashedDecoratedKey;
+import org.apache.cassandra.db.TypeSizes;
 import org.apache.cassandra.db.marshal.AbstractType;
 import org.apache.cassandra.db.marshal.PartitionerDefinedOrder;
 import org.apache.cassandra.db.marshal.LongType;
 import org.apache.cassandra.exceptions.ConfigurationException;
+import org.apache.cassandra.io.util.DataOutputPlus;
 import org.apache.cassandra.utils.ByteBufferUtil;
 import org.apache.cassandra.utils.MurmurHash;
 import org.apache.cassandra.utils.ObjectSizes;
@@ -42,6 +45,7 @@ public class Murmur3Partitioner implements IPartitioner
 {
     public static final LongToken MINIMUM = new LongToken(Long.MIN_VALUE);
     public static final long MAXIMUM = Long.MAX_VALUE;
+    private static final int MAXIMUM_TOKEN_SIZE = TypeSizes.sizeof(MAXIMUM);
 
     private static final int HEAP_SIZE = (int) ObjectSizes.measureDeep(MINIMUM);
 
@@ -224,6 +228,11 @@ public class Murmur3Partitioner implements IPartitioner
         return new LongToken(normalize(hash[0]));
     }
 
+    public int getMaxTokenSize()
+    {
+        return MAXIMUM_TOKEN_SIZE;
+    }
+
     private long[] getHash(ByteBuffer key)
     {
         long[] hash = new long[2];
@@ -300,11 +309,35 @@ public class Murmur3Partitioner implements IPartitioner
             return ByteBufferUtil.bytes(longToken.token);
         }
 
+        @Override
+        public void serialize(Token token, DataOutputPlus out) throws IOException
+        {
+            out.writeLong(((LongToken) token).token);
+        }
+
+        @Override
+        public void serialize(Token token, ByteBuffer out)
+        {
+            out.putLong(((LongToken) token).token);
+        }
+
+        @Override
+        public int byteSize(Token token)
+        {
+            return 8;
+        }
+
         public Token fromByteArray(ByteBuffer bytes)
         {
             return new LongToken(ByteBufferUtil.toLong(bytes));
         }
 
+        @Override
+        public Token fromByteBuffer(ByteBuffer bytes, int position, int length)
+        {
+            return new LongToken(bytes.getLong(position));
+        }
+
         public String toString(Token token)
         {
             return token.toString();
diff --git a/src/java/org/apache/cassandra/dht/RandomPartitioner.java b/src/java/org/apache/cassandra/dht/RandomPartitioner.java
index 4e63475..0457a89 100644
--- a/src/java/org/apache/cassandra/dht/RandomPartitioner.java
+++ b/src/java/org/apache/cassandra/dht/RandomPartitioner.java
@@ -17,6 +17,7 @@
  */
 package org.apache.cassandra.dht;
 
+import java.io.IOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
@@ -31,6 +32,7 @@ import org.apache.cassandra.db.DecoratedKey;
 import org.apache.cassandra.db.marshal.AbstractType;
 import org.apache.cassandra.db.marshal.IntegerType;
 import org.apache.cassandra.db.marshal.PartitionerDefinedOrder;
+import org.apache.cassandra.io.util.DataOutputPlus;
 import org.apache.cassandra.utils.ByteBufferUtil;
 import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.GuidGenerator;
@@ -46,6 +48,7 @@ public class RandomPartitioner implements IPartitioner
     public static final BigInteger ZERO = new BigInteger("0");
     public static final BigIntegerToken MINIMUM = new BigIntegerToken("-1");
     public static final BigInteger MAXIMUM = new BigInteger("2").pow(127);
+    public static final int MAXIMUM_TOKEN_SIZE = MAXIMUM.bitLength() / 8 + 1;
 
     /**
      * Maintain a separate threadlocal message digest, exclusively for token hashing. This is necessary because
@@ -162,11 +165,35 @@ public class RandomPartitioner implements IPartitioner
             return ByteBuffer.wrap(bigIntegerToken.token.toByteArray());
         }
 
+        @Override
+        public void serialize(Token token, DataOutputPlus out) throws IOException
+        {
+            out.write(((BigIntegerToken) token).token.toByteArray());
+        }
+
+        @Override
+        public void serialize(Token token, ByteBuffer out)
+        {
+            out.put(((BigIntegerToken) token).token.toByteArray());
+        }
+
+        @Override
+        public int byteSize(Token token)
+        {
+            return ((BigIntegerToken) token).token.bitLength() / 8 + 1;
+        }
+
         public Token fromByteArray(ByteBuffer bytes)
         {
             return new BigIntegerToken(new BigInteger(ByteBufferUtil.getArray(bytes)));
         }
 
+        @Override
+        public Token fromByteBuffer(ByteBuffer bytes, int position, int length)
+        {
+            return new BigIntegerToken(new BigInteger(ByteBufferUtil.getArray(bytes, position, length)));
+        }
+
         public String toString(Token token)
         {
             BigIntegerToken bigIntegerToken = (BigIntegerToken) token;
@@ -252,6 +279,11 @@ public class RandomPartitioner implements IPartitioner
         return new BigIntegerToken(hashToBigInteger(key));
     }
 
+    public int getMaxTokenSize()
+    {
+        return MAXIMUM_TOKEN_SIZE;
+    }
+
     public Map<Token, Float> describeOwnership(List<Token> sortedTokens)
     {
         Map<Token, Float> ownerships = new HashMap<Token, Float>();
diff --git a/src/java/org/apache/cassandra/dht/Token.java b/src/java/org/apache/cassandra/dht/Token.java
index 20b45ef..ccb66fd 100644
--- a/src/java/org/apache/cassandra/dht/Token.java
+++ b/src/java/org/apache/cassandra/dht/Token.java
@@ -26,7 +26,6 @@ import org.apache.cassandra.db.PartitionPosition;
 import org.apache.cassandra.db.TypeSizes;
 import org.apache.cassandra.exceptions.ConfigurationException;
 import org.apache.cassandra.io.util.DataOutputPlus;
-import org.apache.cassandra.utils.ByteBufferUtil;
 
 public abstract class Token implements RingPosition<Token>, Serializable
 {
@@ -40,8 +39,30 @@ public abstract class Token implements RingPosition<Token>, Serializable
         public abstract Token fromByteArray(ByteBuffer bytes);
         public abstract String toString(Token token); // serialize as string, not necessarily human-readable
         public abstract Token fromString(String string); // deserialize
-
         public abstract void validate(String token) throws ConfigurationException;
+
+        public void serialize(Token token, DataOutputPlus out) throws IOException
+        {
+            out.write(toByteArray(token));
+        }
+
+        public void serialize(Token token, ByteBuffer out) throws IOException
+        {
+            out.put(toByteArray(token));
+        }
+
+        public Token fromByteBuffer(ByteBuffer bytes, int position, int length)
+        {
+            bytes = bytes.duplicate();
+            bytes.position(position)
+                 .limit(position + length);
+            return fromByteArray(bytes);
+        }
+
+        public int byteSize(Token token)
+        {
+            return toByteArray(token).remaining();
+        }
     }
 
     public static class TokenSerializer implements IPartitionerDependentSerializer<Token>
@@ -49,23 +70,28 @@ public abstract class Token implements RingPosition<Token>, Serializable
         public void serialize(Token token, DataOutputPlus out, int version) throws IOException
         {
             IPartitioner p = token.getPartitioner();
-            ByteBuffer b = p.getTokenFactory().toByteArray(token);
-            ByteBufferUtil.writeWithLength(b, out);
+            out.writeInt(p.getTokenFactory().byteSize(token));
+            p.getTokenFactory().serialize(token, out);
         }
 
         public Token deserialize(DataInput in, IPartitioner p, int version) throws IOException
         {
-            int size = in.readInt();
+            int size = deserializeSize(in);
             byte[] bytes = new byte[size];
             in.readFully(bytes);
             return p.getTokenFactory().fromByteArray(ByteBuffer.wrap(bytes));
         }
 
+        public int deserializeSize(DataInput in) throws IOException
+        {
+            return in.readInt();
+        }
+
         public long serializedSize(Token object, int version)
         {
             IPartitioner p = object.getPartitioner();
-            ByteBuffer b = p.getTokenFactory().toByteArray(object);
-            return TypeSizes.sizeof(b.remaining()) + b.remaining();
+            int byteSize = p.getTokenFactory().byteSize(object);
+            return TypeSizes.sizeof(byteSize) + byteSize;
         }
     }
 
diff --git a/src/java/org/apache/cassandra/repair/RepairJob.java b/src/java/org/apache/cassandra/repair/RepairJob.java
index a67aac0..f682bfb 100644
--- a/src/java/org/apache/cassandra/repair/RepairJob.java
+++ b/src/java/org/apache/cassandra/repair/RepairJob.java
@@ -236,7 +236,10 @@ public class RepairJob extends AbstractFuture<RepairResult> implements Runnable
                 }
                 syncTasks.add(task);
             }
+            trees.get(i).trees.release();
         }
+        trees.get(trees.size() - 1).trees.release();
+
         return syncTasks;
     }
 
diff --git a/src/java/org/apache/cassandra/repair/Validator.java b/src/java/org/apache/cassandra/repair/Validator.java
index c8becaa..9a89fa6 100644
--- a/src/java/org/apache/cassandra/repair/Validator.java
+++ b/src/java/org/apache/cassandra/repair/Validator.java
@@ -19,6 +19,7 @@ package org.apache.cassandra.repair;
 
 import java.nio.ByteBuffer;
 import java.nio.charset.Charset;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Random;
@@ -29,7 +30,6 @@ import com.google.common.hash.HashCode;
 import com.google.common.hash.HashFunction;
 import com.google.common.hash.Hasher;
 import com.google.common.hash.Hashing;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -46,6 +46,7 @@ import org.apache.cassandra.net.Message;
 import org.apache.cassandra.net.MessagingService;
 import org.apache.cassandra.repair.messages.ValidationComplete;
 import org.apache.cassandra.streaming.PreviewKind;
+import org.apache.cassandra.service.ActiveRepairService;
 import org.apache.cassandra.tracing.Tracing;
 import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.MerkleTree;
@@ -149,7 +150,7 @@ public class Validator implements Runnable
             }
         }
         logger.debug("Prepared AEService trees of size {} for {}", trees.size(), desc);
-        ranges = tree.invalids();
+        ranges = tree.rangeIterator();
     }
 
     /**
@@ -172,7 +173,7 @@ public class Validator implements Runnable
         if (!findCorrectRange(lastKey.getToken()))
         {
             // add the empty hash, and move to the next range
-            ranges = trees.invalids();
+            ranges = trees.rangeIterator();
             findCorrectRange(lastKey.getToken());
         }
 
@@ -368,9 +369,7 @@ public class Validator implements Runnable
      */
     public void complete()
     {
-        completeTree();
-
-        StageManager.getStage(Stage.ANTI_ENTROPY).execute(this);
+        assert ranges != null : "Validator was not prepared()";
 
         if (logger.isDebugEnabled())
         {
@@ -380,20 +379,8 @@ public class Validator implements Runnable
             logger.debug("Validated {} partitions for {}.  Partition sizes are:", validated, desc.sessionId);
             trees.logRowSizePerLeaf(logger);
         }
-    }
-
-    @VisibleForTesting
-    public void completeTree()
-    {
-        assert ranges != null : "Validator was not prepared()";
-
-        ranges = trees.invalids();
 
-        while (ranges.hasNext())
-        {
-            range = ranges.next();
-            range.ensureHashInitialised();
-        }
+        StageManager.getStage(Stage.ANTI_ENTROPY).execute(this);
     }
 
     /**
@@ -404,8 +391,7 @@ public class Validator implements Runnable
     public void fail()
     {
         logger.error("Failed creating a merkle tree for {}, {} (see log for details)", desc, initiator);
-        // send fail message only to nodes >= version 2.0
-        MessagingService.instance().send(Message.out(REPAIR_REQ, new ValidationComplete(desc)), initiator);
+        respond(new ValidationComplete(desc));
     }
 
     /**
@@ -413,12 +399,51 @@ public class Validator implements Runnable
      */
     public void run()
     {
-        // respond to the request that triggered this validation
-        if (!initiator.equals(FBUtilities.getBroadcastAddressAndPort()))
+        if (initiatorIsRemote())
         {
             logger.info("{} Sending completed merkle tree to {} for {}.{}", previewKind.logPrefix(desc.sessionId), initiator, desc.keyspace, desc.columnFamily);
             Tracing.traceRepair("Sending completed merkle tree to {} for {}.{}", initiator, desc.keyspace, desc.columnFamily);
         }
-        MessagingService.instance().send(Message.out(REPAIR_REQ, new ValidationComplete(desc, trees)), initiator);
+        else
+        {
+            logger.info("{} Local completed merkle tree for {} for {}.{}", previewKind.logPrefix(desc.sessionId), initiator, desc.keyspace, desc.columnFamily);
+            Tracing.traceRepair("Local completed merkle tree for {} for {}.{}", initiator, desc.keyspace, desc.columnFamily);
+
+        }
+        respond(new ValidationComplete(desc, trees));
+    }
+
+    private boolean initiatorIsRemote()
+    {
+        return !FBUtilities.getBroadcastAddressAndPort().equals(initiator);
+    }
+
+    private void respond(ValidationComplete response)
+    {
+        if (initiatorIsRemote())
+        {
+            MessagingService.instance().send(Message.out(REPAIR_REQ, response), initiator);
+            return;
+        }
+
+        /*
+         * For local initiators, DO NOT send the message to self over loopback. This is a wasted ser/de loop
+         * and a ton of garbage. Instead, move the trees off heap and invoke message handler. We could do it
+         * directly, since this method will only be called from {@code Stage.ANTI_ENTROPY}, but we do instead
+         * execute a {@code Runnable} on the stage - in case that assumption ever changes by accident.
+         */
+        StageManager.getStage(Stage.ANTI_ENTROPY).execute(() ->
+        {
+            ValidationComplete movedResponse = response;
+            try
+            {
+                movedResponse = response.tryMoveOffHeap();
+            }
+            catch (IOException e)
+            {
+                logger.error("Failed to move local merkle tree for {} off heap", desc, e);
+            }
+            ActiveRepairService.instance.handleMessage(initiator, movedResponse);
+        });
     }
 }
diff --git a/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java b/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java
index 704bffb..b8aa736 100644
--- a/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java
+++ b/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java
@@ -56,6 +56,15 @@ public class ValidationComplete extends RepairMessage
         return trees != null;
     }
 
+    /**
+     * @return a new {@link ValidationComplete} instance with all trees moved off heap, or {@code this}
+     * if it's a failure response.
+     */
+    public ValidationComplete tryMoveOffHeap() throws IOException
+    {
+        return trees == null ? this : new ValidationComplete(desc, trees.tryMoveOffHeap());
+    }
+
     @Override
     public boolean equals(Object o)
     {
diff --git a/src/java/org/apache/cassandra/service/ActiveRepairService.java b/src/java/org/apache/cassandra/service/ActiveRepairService.java
index abfd6d9..689771c 100644
--- a/src/java/org/apache/cassandra/service/ActiveRepairService.java
+++ b/src/java/org/apache/cassandra/service/ActiveRepairService.java
@@ -248,6 +248,16 @@ public class ActiveRepairService implements IEndpointStateChangeSubscriber, IFai
         return session;
     }
 
+    public boolean getOffheapMerkleTreesEnabled()
+    {
+        return DatabaseDescriptor.getOffheapMerkleTreesEnabled();
+    }
+
+    public void setOffheapMerkleTreesEnabled(boolean enabled)
+    {
+        DatabaseDescriptor.setOffheapMerkleTreesEnabled(enabled);
+    }
+
     private <T extends AbstractFuture &
                IEndpointStateChangeSubscriber &
                IFailureDetectionEventListener> void registerOnFdAndGossip(final T task)
diff --git a/src/java/org/apache/cassandra/utils/ByteBufferUtil.java b/src/java/org/apache/cassandra/utils/ByteBufferUtil.java
index 788300c..518436e 100644
--- a/src/java/org/apache/cassandra/utils/ByteBufferUtil.java
+++ b/src/java/org/apache/cassandra/utils/ByteBufferUtil.java
@@ -101,6 +101,16 @@ public class ByteBufferUtil
         return FastByteOperations.compareUnsigned(o1, o2, 0, o2.length);
     }
 
+    public static int compare(ByteBuffer o1, int s1, int l1, byte[] o2)
+    {
+        return FastByteOperations.compareUnsigned(o1, s1, l1, o2, 0, o2.length);
+    }
+
+    public static int compare(byte[] o1, ByteBuffer o2, int s2, int l2)
+    {
+        return FastByteOperations.compareUnsigned(o1, 0, o1.length, o2, s2, l2);
+    }
+
     /**
      * Decode a String representation.
      * This method assumes that the encoding charset is UTF_8.
@@ -161,16 +171,25 @@ public class ByteBufferUtil
      */
     public static byte[] getArray(ByteBuffer buffer)
     {
-        int length = buffer.remaining();
+        return getArray(buffer, buffer.position(), buffer.remaining());
+    }
+
+    /**
+     * You should almost never use this.  Instead, use the write* methods to avoid copies.
+     */
+    public static byte[] getArray(ByteBuffer buffer, int position, int length)
+    {
         if (buffer.hasArray())
         {
-            int boff = buffer.arrayOffset() + buffer.position();
+            int boff = buffer.arrayOffset() + position;
             return Arrays.copyOfRange(buffer.array(), boff, boff + length);
         }
+
         // else, DirectByteBuffer.get() is the fastest route
         byte[] bytes = new byte[length];
-        buffer.duplicate().get(bytes);
-
+        ByteBuffer dup = buffer.duplicate();
+        dup.position(position).limit(position + length);
+        dup.get(bytes);
         return bytes;
     }
 
@@ -631,6 +650,7 @@ public class ByteBufferUtil
 
         assert bytes1.limit() >= offset1 + length : "The first byte array isn't long enough for the specified offset and length.";
         assert bytes2.limit() >= offset2 + length : "The second byte array isn't long enough for the specified offset and length.";
+
         for (int i = 0; i < length; i++)
         {
             byte byte1 = bytes1.get(offset1 + i);
@@ -669,7 +689,7 @@ public class ByteBufferUtil
         return buf.capacity() > buf.remaining() || !buf.hasArray() ? ByteBuffer.wrap(getArray(buf)) : buf;
     }
 
-    // Doesn't change bb position
+    // doesn't change bb position
     public static int getShortLength(ByteBuffer bb, int position)
     {
         int length = (bb.get(position) & 0xFF) << 8;
diff --git a/src/java/org/apache/cassandra/utils/FBUtilities.java b/src/java/org/apache/cassandra/utils/FBUtilities.java
index c37dcca..0f7b2ca 100644
--- a/src/java/org/apache/cassandra/utils/FBUtilities.java
+++ b/src/java/org/apache/cassandra/utils/FBUtilities.java
@@ -289,6 +289,22 @@ public class FBUtilities
         return out;
     }
 
+    /**
+     * Bitwise XOR of the inputs, in place on the left array
+     *
+     * Assumes inputs are same length
+     */
+    static void xorOntoLeft(byte[] left, byte[] right)
+    {
+        if (left == null || right == null)
+            return;
+
+        assert left.length == right.length;
+
+        for (int i = 0; i < left.length; i++)
+            left[i] = (byte) ((left[i] & 0xFF) ^ (right[i] & 0xFF));
+    }
+
     public static void sortSampledKeys(List<DecoratedKey> keys, Range<Token> range)
     {
         if (range.left.compareTo(range.right) >= 0)
diff --git a/src/java/org/apache/cassandra/utils/FastByteOperations.java b/src/java/org/apache/cassandra/utils/FastByteOperations.java
index 6581736..060dee5 100644
--- a/src/java/org/apache/cassandra/utils/FastByteOperations.java
+++ b/src/java/org/apache/cassandra/utils/FastByteOperations.java
@@ -55,6 +55,16 @@ public class FastByteOperations
         return -BestHolder.BEST.compare(b2, b1, s1, l1);
     }
 
+    public static int compareUnsigned(ByteBuffer b1, int s1, int l1, byte[] b2, int s2, int l2)
+    {
+        return BestHolder.BEST.compare(b1, s1, l1, b2, s2, l2);
+    }
+
+    public static int compareUnsigned(byte[] b1, int s1, int l1, ByteBuffer b2, int s2, int l2)
+    {
+        return -BestHolder.BEST.compare(b2, s2, l2, b1, s1, l1);
+    }
+
     public static int compareUnsigned(ByteBuffer b1, ByteBuffer b2)
     {
         return BestHolder.BEST.compare(b1, b2);
@@ -77,6 +87,8 @@ public class FastByteOperations
 
         abstract public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2);
 
+        abstract public int compare(ByteBuffer buffer1, int offset1, int length1, byte[] buffer2, int offset2, int length2);
+
         abstract public int compare(ByteBuffer buffer1, ByteBuffer buffer2);
 
         abstract public void copy(ByteBuffer src, int srcPosition, byte[] trg, int trgPosition, int length);
@@ -187,25 +199,24 @@ public class FastByteOperations
 
         public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2)
         {
+            return compare(buffer1, buffer1.position(), buffer1.remaining(), buffer2, offset2, length2);
+        }
+
+        public int compare(ByteBuffer buffer1, int position1, int length1, byte[] buffer2, int offset2, int length2)
+        {
             Object obj1;
             long offset1;
             if (buffer1.hasArray())
             {
                 obj1 = buffer1.array();
-                offset1 = BYTE_ARRAY_BASE_OFFSET + buffer1.arrayOffset();
+                offset1 = BYTE_ARRAY_BASE_OFFSET + buffer1.arrayOffset() + position1;
             }
             else
             {
                 obj1 = null;
-                offset1 = theUnsafe.getLong(buffer1, DIRECT_BUFFER_ADDRESS_OFFSET);
-            }
-            int length1;
-            {
-                int position = buffer1.position();
-                int limit = buffer1.limit();
-                length1 = limit - position;
-                offset1 += position;
+                offset1 = theUnsafe.getLong(buffer1, DIRECT_BUFFER_ADDRESS_OFFSET) + position1;
             }
+
             return compareTo(obj1, offset1, length1, buffer2, BYTE_ARRAY_BASE_OFFSET + offset2, length2);
         }
 
@@ -397,11 +408,28 @@ public class FastByteOperations
             return length1 - length2;
         }
 
+        public int compare(ByteBuffer buffer1, int position1, int length1, byte[] buffer2, int offset2, int length2)
+        {
+            if (buffer1.hasArray())
+                return compare(buffer1.array(), buffer1.arrayOffset() + position1, length1, buffer2, offset2, length2);
+
+            if (position1 != buffer1.position())
+            {
+                buffer1 = buffer1.duplicate();
+                buffer1.position(position1);
+            }
+
+            return compare(buffer1, ByteBuffer.wrap(buffer2, offset2, length2));
+        }
+
         public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2)
         {
             if (buffer1.hasArray())
+            {
                 return compare(buffer1.array(), buffer1.arrayOffset() + buffer1.position(), buffer1.remaining(),
                                buffer2, offset2, length2);
+            }
+
             return compare(buffer1, ByteBuffer.wrap(buffer2, offset2, length2));
         }
 
diff --git a/src/java/org/apache/cassandra/utils/MerkleTree.java b/src/java/org/apache/cassandra/utils/MerkleTree.java
index d131ff5..9d9eadb 100644
--- a/src/java/org/apache/cassandra/utils/MerkleTree.java
+++ b/src/java/org/apache/cassandra/utils/MerkleTree.java
@@ -19,25 +19,33 @@ package org.apache.cassandra.utils;
 
 import java.io.DataInput;
 import java.io.IOException;
-import java.io.Serializable;
+import java.nio.ByteBuffer;
 import java.util.*;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.PeekingIterator;
+import com.google.common.primitives.Ints;
+import com.google.common.primitives.Shorts;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.cassandra.db.TypeSizes;
+import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.dht.IPartitioner;
-import org.apache.cassandra.dht.IPartitionerDependentSerializer;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.dht.RandomPartitioner;
 import org.apache.cassandra.dht.Range;
 import org.apache.cassandra.dht.Token;
 import org.apache.cassandra.exceptions.ConfigurationException;
-import org.apache.cassandra.io.IVersionedSerializer;
 import org.apache.cassandra.io.util.DataInputPlus;
 import org.apache.cassandra.io.util.DataOutputPlus;
+import org.apache.cassandra.io.util.FileUtils;
+import org.apache.cassandra.utils.concurrent.Ref;
+import org.apache.cassandra.utils.memory.MemoryUtil;
+
+import static org.apache.cassandra.db.TypeSizes.sizeof;
+import static org.apache.cassandra.utils.ByteBufferUtil.compare;
 
 /**
  * A MerkleTree implemented as a binary tree.
@@ -59,84 +67,51 @@ import org.apache.cassandra.io.util.DataOutputPlus;
  * If two MerkleTrees have the same hashdepth, they represent a perfect tree
  * of the same depth, and can always be compared, regardless of size or splits.
  */
-public class MerkleTree implements Serializable
+public class MerkleTree
 {
-    private static Logger logger = LoggerFactory.getLogger(MerkleTree.class);
+    private static final Logger logger = LoggerFactory.getLogger(MerkleTree.class);
+
+    private static final int HASH_SIZE = 32; // SHA-256 = 32 bytes.
+    private static final byte[] EMPTY_HASH = new byte[HASH_SIZE];
+
+    /*
+     * Thread-local byte array, large enough to host SHA256 bytes or MM3/Random partitioners' tokens
+     */
+    private static final ThreadLocal<byte[]> byteArray = ThreadLocal.withInitial(() -> new byte[HASH_SIZE]);
 
-    public static final MerkleTreeSerializer serializer = new MerkleTreeSerializer();
-    private static final long serialVersionUID = 2L;
+    private static byte[] getTempArray(int minimumSize)
+    {
+        return minimumSize <= HASH_SIZE ? byteArray.get() : new byte[minimumSize];
+    }
 
     public static final byte RECOMMENDED_DEPTH = Byte.MAX_VALUE - 1;
 
-    public static final int CONSISTENT = 0;
-    public static final int FULLY_INCONSISTENT = 1;
-    public static final int PARTIALLY_INCONSISTENT = 2;
-    private static final byte[] EMPTY_HASH = new byte[0];
+    @SuppressWarnings("WeakerAccess")
+    static final int CONSISTENT = 0;
+    static final int FULLY_INCONSISTENT = 1;
+    @SuppressWarnings("WeakerAccess")
+    static final int PARTIALLY_INCONSISTENT = 2;
 
-    public final byte hashdepth;
+    private final int hashdepth;
 
     /** The top level range that this MerkleTree covers. */
-    public final Range<Token> fullRange;
+    final Range<Token> fullRange;
     private final IPartitioner partitioner;
 
     private long maxsize;
     private long size;
-    private Hashable root;
+    private Node root;
 
-    public static class MerkleTreeSerializer implements IVersionedSerializer<MerkleTree>
+    /**
+     * @param partitioner The partitioner in use.
+     * @param range the range this tree covers
+     * @param hashdepth The maximum depth of the tree. 100/(2^depth) is the %
+     *        of the key space covered by each subrange of a fully populated tree.
+     * @param maxsize The maximum number of subranges in the tree.
+     */
+    public MerkleTree(IPartitioner partitioner, Range<Token> range, int hashdepth, long maxsize)
     {
-        public void serialize(MerkleTree mt, DataOutputPlus out, int version) throws IOException
-        {
-            out.writeByte(mt.hashdepth);
-            out.writeLong(mt.maxsize);
-            out.writeLong(mt.size);
-            out.writeUTF(mt.partitioner.getClass().getCanonicalName());
-            // full range
-            Token.serializer.serialize(mt.fullRange.left, out, version);
-            Token.serializer.serialize(mt.fullRange.right, out, version);
-            Hashable.serializer.serialize(mt.root, out, version);
-        }
-
-        public MerkleTree deserialize(DataInputPlus in, int version) throws IOException
-        {
-            byte hashdepth = in.readByte();
-            long maxsize = in.readLong();
-            long size = in.readLong();
-            IPartitioner partitioner;
-            try
-            {
-                partitioner = FBUtilities.newPartitioner(in.readUTF());
-            }
-            catch (ConfigurationException e)
-            {
-                throw new IOException(e);
-            }
-
-            // full range
-            Token left = Token.serializer.deserialize(in, partitioner, version);
-            Token right = Token.serializer.deserialize(in, partitioner, version);
-            Range<Token> fullRange = new Range<>(left, right);
-
-            MerkleTree mt = new MerkleTree(partitioner, fullRange, hashdepth, maxsize);
-            mt.size = size;
-            mt.root = Hashable.serializer.deserialize(in, partitioner, version);
-            return mt;
-        }
-
-        public long serializedSize(MerkleTree mt, int version)
-        {
-            long size = 1 // mt.hashdepth
-                 + TypeSizes.sizeof(mt.maxsize)
-                 + TypeSizes.sizeof(mt.size)
-                 + TypeSizes.sizeof(mt.partitioner.getClass().getCanonicalName());
-
-            // full range
-            size += Token.serializer.serializedSize(mt.fullRange.left, version);
-            size += Token.serializer.serializedSize(mt.fullRange.right, version);
-
-            size += Hashable.serializer.serializedSize(mt.root, version);
-            return size;
-        }
+        this(new OnHeapLeaf(), partitioner, range, hashdepth, maxsize, 1);
     }
 
     /**
@@ -145,60 +120,56 @@ public class MerkleTree implements Serializable
      * @param hashdepth The maximum depth of the tree. 100/(2^depth) is the %
      *        of the key space covered by each subrange of a fully populated tree.
      * @param maxsize The maximum number of subranges in the tree.
+     * @param size The size of the tree. Typically 1, unless deserialized from an existing tree
      */
-    public MerkleTree(IPartitioner partitioner, Range<Token> range, byte hashdepth, long maxsize)
+    private MerkleTree(Node root, IPartitioner partitioner, Range<Token> range, int hashdepth, long maxsize, long size)
     {
         assert hashdepth < Byte.MAX_VALUE;
+
+        this.root = root;
         this.fullRange = Preconditions.checkNotNull(range);
         this.partitioner = Preconditions.checkNotNull(partitioner);
         this.hashdepth = hashdepth;
         this.maxsize = maxsize;
-
-        size = 1;
-        root = new Leaf(null);
-    }
-
-
-    static byte inc(byte in)
-    {
-        assert in < Byte.MAX_VALUE;
-        return (byte)(in + 1);
+        this.size = size;
     }
 
     /**
      * Initializes this tree by splitting it until hashdepth is reached,
      * or until an additional level of splits would violate maxsize.
      *
-     * NB: Replaces all nodes in the tree.
+     * NB: Replaces all nodes in the tree, and always builds on the heap
      */
     public void init()
     {
         // determine the depth to which we can safely split the tree
-        byte sizedepth = (byte)(Math.log10(maxsize) / Math.log10(2));
-        byte depth = (byte)Math.min(sizedepth, hashdepth);
+        int sizedepth = (int) (Math.log10(maxsize) / Math.log10(2));
+        int depth = Math.min(sizedepth, hashdepth);
+
+        root = initHelper(fullRange.left, fullRange.right, 0, depth);
+        size = (long) Math.pow(2, depth);
+    }
 
-        root = initHelper(fullRange.left, fullRange.right, (byte)0, depth);
-        size = (long)Math.pow(2, depth);
+    public void release()
+    {
+        if (root instanceof OffHeapNode)
+            ((OffHeapNode) root).release();
+        root = null;
     }
 
-    private Hashable initHelper(Token left, Token right, byte depth, byte max)
+    private OnHeapNode initHelper(Token left, Token right, int depth, int max)
     {
         if (depth == max)
             // we've reached the leaves
-            return new Leaf();
+            return new OnHeapLeaf();
         Token midpoint = partitioner.midpoint(left, right);
 
         if (midpoint.equals(left) || midpoint.equals(right))
-            return new Leaf();
-
-        Hashable lchild =  initHelper(left, midpoint, inc(depth), max);
-        Hashable rchild =  initHelper(midpoint, right, inc(depth), max);
-        return new Inner(midpoint, lchild, rchild);
-    }
+            return new OnHeapLeaf();
 
-    Hashable root()
-    {
-        return root;
+        OnHeapNode leftChild = initHelper(left, midpoint, depth + 1, max);
+        OnHeapNode rightChild = initHelper(midpoint, right, depth + 1, max);
+        return new OnHeapInner(midpoint, leftChild, rightChild);
     }
 
     public IPartitioner partitioner()
@@ -233,20 +204,17 @@ public class MerkleTree implements Serializable
     public static List<TreeRange> difference(MerkleTree ltree, MerkleTree rtree)
     {
         if (!ltree.fullRange.equals(rtree.fullRange))
-            throw new IllegalArgumentException("Difference only make sense on tree covering the same range (but " + ltree.fullRange + " != " + rtree.fullRange + ")");
+            throw new IllegalArgumentException("Difference only make sense on tree covering the same range (but " + ltree.fullRange + " != " + rtree.fullRange + ')');
 
         List<TreeRange> diff = new ArrayList<>();
-        TreeDifference active = new TreeDifference(ltree.fullRange.left, ltree.fullRange.right, (byte)0);
+        TreeRange active = new TreeRange(ltree.fullRange.left, ltree.fullRange.right, 0);
 
-        Hashable lnode = ltree.find(active);
-        Hashable rnode = rtree.find(active);
-        byte[] lhash = lnode.hash();
-        byte[] rhash = rnode.hash();
-        active.setSize(lnode.sizeOfRange(), rnode.sizeOfRange());
+        Node lnode = ltree.find(active);
+        Node rnode = rtree.find(active);
 
-        if (lhash != null && rhash != null && !Arrays.equals(lhash, rhash))
+        if (lnode.hashesDiffer(rnode))
         {
-            if(lnode instanceof  Leaf || rnode instanceof Leaf)
+            if (lnode instanceof Leaf || rnode instanceof Leaf)
             {
                 logger.debug("Digest mismatch detected among leaf nodes {}, {}", lnode, rnode);
                 diff.add(active);
@@ -261,14 +229,12 @@ public class MerkleTree implements Serializable
                 }
             }
         }
-        else if (lhash == null || rhash == null)
-            diff.add(active);
+
         return diff;
     }
 
     /**
-     * TODO: This function could be optimized into a depth first traversal of
-     * the two trees in parallel.
+     * TODO: This function could be optimized into a depth first traversal of the two trees in parallel.
      *
      * Takes two trees and a range for which they have hashes, but are inconsistent.
      * @return FULLY_INCONSISTENT if active is inconsistent, PARTIALLY_INCONSISTENT if only a subrange is inconsistent.
@@ -289,54 +255,39 @@ public class MerkleTree implements Serializable
             return FULLY_INCONSISTENT;
         }
 
-        TreeDifference left = new TreeDifference(active.left, midpoint, inc(active.depth));
-        TreeDifference right = new TreeDifference(midpoint, active.right, inc(active.depth));
+        TreeRange left = new TreeRange(active.left, midpoint, active.depth + 1);
+        TreeRange right = new TreeRange(midpoint, active.right, active.depth + 1);
         logger.debug("({}) Hashing sub-ranges [{}, {}] for {} divided by midpoint {}", active.depth, left, right, active, midpoint);
-        byte[] lhash, rhash;
-        Hashable lnode, rnode;
+        Node lnode, rnode;
 
         // see if we should recurse left
         lnode = ltree.find(left);
         rnode = rtree.find(left);
-        lhash = lnode.hash();
-        rhash = rnode.hash();
-        left.setSize(lnode.sizeOfRange(), rnode.sizeOfRange());
-        left.setRows(lnode.rowsInRange(), rnode.rowsInRange());
 
         int ldiff = CONSISTENT;
-        boolean lreso = lhash != null && rhash != null;
-        if (lreso && !Arrays.equals(lhash, rhash))
+        if (lnode.hashesDiffer(rnode))
         {
             logger.debug("({}) Inconsistent digest on left sub-range {}: [{}, {}]", active.depth, left, lnode, rnode);
-            if (lnode instanceof Leaf) ldiff = FULLY_INCONSISTENT;
-            else ldiff = differenceHelper(ltree, rtree, diff, left);
-        }
-        else if (!lreso)
-        {
-            logger.debug("({}) Left sub-range fully inconsistent {}", active.depth, left);
-            ldiff = FULLY_INCONSISTENT;
+
+            if (lnode instanceof Leaf)
+                ldiff = FULLY_INCONSISTENT;
+            else
+                ldiff = differenceHelper(ltree, rtree, diff, left);
         }
 
         // see if we should recurse right
         lnode = ltree.find(right);
         rnode = rtree.find(right);
-        lhash = lnode.hash();
-        rhash = rnode.hash();
-        right.setSize(lnode.sizeOfRange(), rnode.sizeOfRange());
-        right.setRows(lnode.rowsInRange(), rnode.rowsInRange());
 
         int rdiff = CONSISTENT;
-        boolean rreso = lhash != null && rhash != null;
-        if (rreso && !Arrays.equals(lhash, rhash))
+        if (lnode.hashesDiffer(rnode))
         {
             logger.debug("({}) Inconsistent digest on right sub-range {}: [{}, {}]", active.depth, right, lnode, rnode);
-            if (rnode instanceof Leaf) rdiff = FULLY_INCONSISTENT;
-            else rdiff = differenceHelper(ltree, rtree, diff, right);
-        }
-        else if (!rreso)
-        {
-            logger.debug("({}) Right sub-range fully inconsistent {}", active.depth, right);
-            rdiff = FULLY_INCONSISTENT;
+
+            if (rnode instanceof Leaf)
+                rdiff = FULLY_INCONSISTENT;
+            else
+                rdiff = differenceHelper(ltree, rtree, diff, right);
         }
 
         if (ldiff == FULLY_INCONSISTENT && rdiff == FULLY_INCONSISTENT)
@@ -367,32 +318,36 @@ public class MerkleTree implements Serializable
      */
     public TreeRange get(Token t)
     {
-        return getHelper(root, fullRange.left, fullRange.right, (byte)0, t);
+        return getHelper(root, fullRange.left, fullRange.right, t);
     }
 
-    TreeRange getHelper(Hashable hashable, Token pleft, Token pright, byte depth, Token t)
+    private TreeRange getHelper(Node node, Token pleft, Token pright, Token t)
     {
+        int depth = 0;
+
         while (true)
         {
-            if (hashable instanceof Leaf)
+            if (node instanceof Leaf)
             {
                 // we've reached a hash: wrap it up and deliver it
-                return new TreeRange(this, pleft, pright, depth, hashable);
+                return new TreeRange(this, pleft, pright, depth, node);
             }
-            // else: node.
-
-            Inner node = (Inner) hashable;
-            depth = inc(depth);
-            if (Range.contains(pleft, node.token, t))
-            { // left child contains token
-                hashable = node.lchild;
-                pright = node.token;
+
+            assert node instanceof Inner;
+            Inner inner = (Inner) node;
+
+            if (Range.contains(pleft, inner.token(), t)) // left child contains token
+            {
+                pright = inner.token();
+                node = inner.left();
             }
-            else
-            { // else: right child contains token
-                hashable = node.rchild;
-                pleft = node.token;
+            else // right child contains token
+            {
+                pleft = inner.token();
+                node = inner.right();
             }
+
+            depth++;
         }
     }
 
@@ -405,20 +360,21 @@ public class MerkleTree implements Serializable
         invalidateHelper(root, fullRange.left, t);
     }
 
-    private void invalidateHelper(Hashable hashable, Token pleft, Token t)
+    private void invalidateHelper(Node node, Token pleft, Token t)
     {
-        hashable.hash(null);
-        if (hashable instanceof Leaf)
+        node.hash(EMPTY_HASH);
+
+        if (node instanceof Leaf)
             return;
-        // else: node.
 
-        Inner node = (Inner)hashable;
-        if (Range.contains(pleft, node.token, t))
-            // left child contains token
-            invalidateHelper(node.lchild, pleft, t);
+        assert node instanceof Inner;
+        Inner inner = (Inner) node;
+        // TODO: reset computed flag on OnHeapInners
+
+        if (Range.contains(pleft, inner.token(), t))
+            invalidateHelper(inner.left(), pleft, t); // left child contains token
         else
-            // right child contains token
-            invalidateHelper(node.rchild, node.token, t);
+            invalidateHelper(inner.right(), inner.token(), t); // right child contains token
     }
 
     /**
@@ -428,67 +384,106 @@ public class MerkleTree implements Serializable
      * NB: Currently does not support wrapping ranges that do not end with
      * partitioner.getMinimumToken().
      *
-     * @return Null if any subrange of the range is invalid, or if the exact
+     * @return {@link #EMPTY_HASH} if any subrange of the range is invalid, or if the exact
      *         range cannot be calculated using this tree.
      */
+    @VisibleForTesting
     public byte[] hash(Range<Token> range)
     {
         return find(range).hash();
     }
 
     /**
-     * Find the {@link Hashable} node that matches the given {@code range}.
+     * Exceptions that stop recursion early when we are sure that no answer
+     * can be found.
+     */
+    static abstract class StopRecursion extends Exception
+    {
+        static class  TooDeep extends StopRecursion {}
+        static class BadRange extends StopRecursion {}
+    }
+
+    /**
+     * Find the {@link Node} node that matches the given {@code range}.
      *
      * @param range Range to find
-     * @return {@link Hashable} found. If nothing found, return {@link Leaf} with null hash.
+     * @return {@link Node} found. If nothing found, return {@link Leaf} with empty hash.
      */
-    private Hashable find(Range<Token> range)
+    private Node find(Range<Token> range)
+    {
+        try
+        {
+            return findHelper(root, new Range<>(fullRange.left, fullRange.right), range);
+        }
+        catch (StopRecursion e)
+        {
+            return new OnHeapLeaf();
+        }
+    }
+
+    interface Consumer<E extends Exception>
+    {
+        void accept(Node node) throws E;
+    }
+
+    @VisibleForTesting
+    <E extends Exception> boolean ifHashesRange(Range<Token> range, Consumer<E> consumer) throws E
     {
         try
         {
-            return findHelper(root, new Range<Token>(fullRange.left, fullRange.right), range);
+            Node node = findHelper(root, new Range<>(fullRange.left, fullRange.right), range);
+            boolean hasHash = !node.hasEmptyHash();
+            if (hasHash)
+                consumer.accept(node);
+            return hasHash;
         }
         catch (StopRecursion e)
         {
-            return new Leaf();
+            return false;
         }
     }
 
+    @VisibleForTesting
+    boolean hashesRange(Range<Token> range)
+    {
+        return ifHashesRange(range, n -> {});
+    }
+
     /**
      * @throws StopRecursion If no match could be found for the range.
      */
-    private Hashable findHelper(Hashable current, Range<Token> activeRange, Range<Token> find) throws StopRecursion
+    private Node findHelper(Node current, Range<Token> activeRange, Range<Token> find) throws StopRecursion
     {
         while (true)
         {
             if (current instanceof Leaf)
             {
                 if (!find.contains(activeRange))
-                    // we are not fully contained in this range!
-                    throw new StopRecursion.BadRange();
+                    throw new StopRecursion.BadRange(); // we are not fully contained in this range!
+
                 return current;
             }
-            // else: node.
 
-            Inner node = (Inner) current;
-            Range<Token> leftRange = new Range<>(activeRange.left, node.token);
-            Range<Token> rightRange = new Range<>(node.token, activeRange.right);
+            assert current instanceof Inner;
+            Inner inner = (Inner) current;
 
-            if (find.contains(activeRange))
-                // this node is fully contained in the range
-                return node.calc();
+            Range<Token> leftRange  = new Range<>(activeRange.left, inner.token());
+            Range<Token> rightRange = new Range<>(inner.token(), activeRange.right);
+
+            if (find.contains(activeRange)) // this node is fully contained in the range
+                return inner.compute();
 
             // else: one of our children contains the range
 
-            if (leftRange.contains(find))
-            { // left child contains/matches the range
-                current = node.lchild;
+            if (leftRange.contains(find)) // left child contains/matches the range
+            {
                 activeRange = leftRange;
+                current = inner.left();
             }
-            else if (rightRange.contains(find))
-            { // right child contains/matches the range
-                current = node.rchild;
+            else if (rightRange.contains(find)) // right child contains/matches the range
+            {
                 activeRange = rightRange;
+                current = inner.right();
             }
             else
             {
@@ -506,12 +501,12 @@ public class MerkleTree implements Serializable
      */
     public boolean split(Token t)
     {
-        if (!(size < maxsize))
+        if (size >= maxsize)
             return false;
 
         try
         {
-            root = splitHelper(root, fullRange.left, fullRange.right, (byte)0, t);
+            root = splitHelper(root, fullRange.left, fullRange.right, 0, t);
         }
         catch (StopRecursion.TooDeep e)
         {
@@ -520,12 +515,12 @@ public class MerkleTree implements Serializable
         return true;
     }
 
-    private Hashable splitHelper(Hashable hashable, Token pleft, Token pright, byte depth, Token t) throws StopRecursion.TooDeep
+    private OnHeapNode splitHelper(Node node, Token pleft, Token pright, int depth, Token t) throws StopRecursion.TooDeep
     {
         if (depth >= hashdepth)
             throw new StopRecursion.TooDeep();
 
-        if (hashable instanceof Leaf)
+        if (node instanceof Leaf)
         {
             Token midpoint = partitioner.midpoint(pleft, pright);
 
@@ -536,47 +531,47 @@ public class MerkleTree implements Serializable
 
             // split
             size++;
-            return new Inner(midpoint, new Leaf(), new Leaf());
+            return new OnHeapInner(midpoint, new OnHeapLeaf(), new OnHeapLeaf());
         }
         // else: node.
 
         // recurse on the matching child
-        Inner node = (Inner)hashable;
+        assert node instanceof OnHeapInner;
+        OnHeapInner inner = (OnHeapInner) node;
 
-        if (Range.contains(pleft, node.token, t))
-            // left child contains token
-            node.lchild(splitHelper(node.lchild, pleft, node.token, inc(depth), t));
-        else
-            // else: right child contains token
-            node.rchild(splitHelper(node.rchild, node.token, pright, inc(depth), t));
-        return node;
+        if (Range.contains(pleft, inner.token(), t)) // left child contains token
+            inner.left(splitHelper(inner.left(), pleft, inner.token(), depth + 1, t));
+        else // else: right child contains token
+            inner.right(splitHelper(inner.right(), inner.token(), pright, depth + 1, t));
+
+        return inner;
     }
 
     /**
      * Returns a lazy iterator of invalid TreeRanges that need to be filled
      * in order to make the given Range valid.
      */
-    public TreeRangeIterator invalids()
+    TreeRangeIterator rangeIterator()
     {
         return new TreeRangeIterator(this);
     }
 
-    public EstimatedHistogram histogramOfRowSizePerLeaf()
+    EstimatedHistogram histogramOfRowSizePerLeaf()
     {
         HistogramBuilder histbuild = new HistogramBuilder();
         for (TreeRange range : new TreeRangeIterator(this))
         {
-            histbuild.add(range.hashable.sizeOfRange);
+            histbuild.add(range.node.sizeOfRange());
         }
         return histbuild.buildWithStdevRangesAroundMean();
     }
 
-    public EstimatedHistogram histogramOfRowCountPerLeaf()
+    EstimatedHistogram histogramOfRowCountPerLeaf()
     {
         HistogramBuilder histbuild = new HistogramBuilder();
         for (TreeRange range : new TreeRangeIterator(this))
         {
-            histbuild.add(range.hashable.rowsInRange);
+            histbuild.add(range.node.partitionsInRange());
         }
         return histbuild.buildWithStdevRangesAroundMean();
     }
@@ -586,7 +581,7 @@ public class MerkleTree implements Serializable
         long count = 0;
         for (TreeRange range : new TreeRangeIterator(this))
         {
-            count += range.hashable.rowsInRange;
+            count += range.node.partitionsInRange();
         }
         return count;
     }
@@ -597,61 +592,23 @@ public class MerkleTree implements Serializable
         StringBuilder buff = new StringBuilder();
         buff.append("#<MerkleTree root=");
         root.toString(buff, 8);
-        buff.append(">");
+        buff.append('>');
         return buff.toString();
     }
 
-    public static class TreeDifference extends TreeRange
+    @Override
+    public boolean equals(Object other)
     {
-        private static final long serialVersionUID = 6363654174549968183L;
-
-        private long sizeOnLeft;
-        private long sizeOnRight;
-        private long rowsOnLeft;
-        private long rowsOnRight;
-
-        void setSize(long sizeOnLeft, long sizeOnRight)
-        {
-            this.sizeOnLeft = sizeOnLeft;
-            this.sizeOnRight = sizeOnRight;
-        }
-
-        void setRows(long rowsOnLeft, long rowsOnRight)
-        {
-            this.rowsOnLeft = rowsOnLeft;
-            this.rowsOnRight = rowsOnRight;
-        }
-
-        public long sizeOnLeft()
-        {
-            return sizeOnLeft;
-        }
-
-        public long sizeOnRight()
-        {
-            return sizeOnRight;
-        }
-
-        public long rowsOnLeft()
-        {
-            return rowsOnLeft;
-        }
-
-        public long rowsOnRight()
-        {
-            return rowsOnRight;
-        }
-
-        public TreeDifference(Token left, Token right, byte depth)
-        {
-            super(null, left, right, depth, null);
-        }
-
-        public long totalRows()
-        {
-            return rowsOnLeft + rowsOnRight;
-        }
-
+        if (!(other instanceof MerkleTree))
+            return false;
+        MerkleTree that = (MerkleTree) other;
+
+        return this.root.equals(that.root)
+            && this.fullRange.equals(that.fullRange)
+            && this.partitioner == that.partitioner
+            && this.hashdepth == that.hashdepth
+            && this.maxsize == that.maxsize
+            && this.size == that.size;
     }
 
     /**
@@ -664,28 +621,27 @@ public class MerkleTree implements Serializable
      */
     public static class TreeRange extends Range<Token>
     {
-        public static final long serialVersionUID = 1L;
         private final MerkleTree tree;
-        public final byte depth;
-        private final Hashable hashable;
+        public final int depth;
+        private final Node node;
 
-        TreeRange(MerkleTree tree, Token left, Token right, byte depth, Hashable hashable)
+        TreeRange(MerkleTree tree, Token left, Token right, int depth, Node node)
         {
             super(left, right);
             this.tree = tree;
             this.depth = depth;
-            this.hashable = hashable;
+            this.node = node;
         }
 
-        public void hash(byte[] hash)
+        TreeRange(Token left, Token right, int depth)
         {
-            assert tree != null : "Not intended for modification!";
-            hashable.hash(hash);
+            this(null, left, right, depth, null);
         }
 
-        public byte[] hash()
+        public void hash(byte[] hash)
         {
-            return hashable.hash();
+            assert tree != null : "Not intended for modification!";
+            node.hash(hash);
         }
 
         /**
@@ -694,18 +650,9 @@ public class MerkleTree implements Serializable
         public void addHash(RowHash entry)
         {
             assert tree != null : "Not intended for modification!";
-            assert hashable instanceof Leaf;
-
-            hashable.addHash(entry.hash, entry.size);
-        }
-
-        public void ensureHashInitialised()
-        {
-            assert tree != null : "Not intended for modification!";
-            assert hashable instanceof Leaf;
 
-            if (hashable.hash == null)
-                hashable.hash = EMPTY_HASH;
+            assert node instanceof OnHeapLeaf;
+            ((OnHeapLeaf) node).addHash(entry.hash, entry.size);
         }
 
         public void addAll(Iterator<RowHash> entries)
@@ -717,9 +664,7 @@ public class MerkleTree implements Serializable
         @Override
         public String toString()
         {
-            StringBuilder buff = new StringBuilder("#<TreeRange ");
-            buff.append(super.toString()).append(" depth=").append(depth);
-            return buff.append(">").toString();
+            return "#<TreeRange " + super.toString() + " depth=" + depth + '>';
         }
     }
 
@@ -740,8 +685,8 @@ public class MerkleTree implements Serializable
 
         TreeRangeIterator(MerkleTree tree)
         {
-            tovisit = new ArrayDeque<TreeRange>();
-            tovisit.add(new TreeRange(tree, tree.fullRange.left, tree.fullRange.right, (byte)0, tree.root));
+            tovisit = new ArrayDeque<>();
+            tovisit.add(new TreeRange(tree, tree.fullRange.left, tree.fullRange.right, 0, tree.root));
             this.tree = tree;
         }
 
@@ -756,7 +701,7 @@ public class MerkleTree implements Serializable
             {
                 TreeRange active = tovisit.pop();
 
-                if (active.hashable instanceof Leaf)
+                if (active.node instanceof Leaf)
                 {
                     // found a leaf invalid range
                     if (active.isWrapAround() && !tovisit.isEmpty())
@@ -765,9 +710,9 @@ public class MerkleTree implements Serializable
                     return active;
                 }
 
-                Inner node = (Inner)active.hashable;
-                TreeRange left = new TreeRange(tree, active.left, node.token, inc(active.depth), node.lchild);
-                TreeRange right = new TreeRange(tree, node.token, active.right, inc(active.depth), node.rchild);
+                Inner node = (Inner)active.node;
+                TreeRange left = new TreeRange(tree, active.left, node.token(), active.depth + 1, node.left());
+                TreeRange right = new TreeRange(tree, node.token(), active.right, active.depth + 1, node.right());
 
                 if (right.isWrapAround())
                 {
@@ -792,123 +737,357 @@ public class MerkleTree implements Serializable
     }
 
     /**
-     * An inner node in the MerkleTree. Inners can contain cached hash values, which
-     * are the binary hash of their two children.
+     * Hash value representing a row, to be used to pass hashes to the MerkleTree.
+     * The byte[] hash value should contain a digest of the key and value of the row
+     * created using a very strong hash function.
      */
-    static class Inner extends Hashable
+    public static class RowHash
     {
-        public static final long serialVersionUID = 1L;
-        static final byte IDENT = 2;
         public final Token token;
-        private Hashable lchild;
-        private Hashable rchild;
-
-        private static final InnerSerializer serializer = new InnerSerializer();
+        public final byte[] hash;
+        public final long size;
 
-        /**
-         * Constructs an Inner with the given token and children, and a null hash.
-         */
-        public Inner(Token token, Hashable lchild, Hashable rchild)
+        public RowHash(Token token, byte[] hash, long size)
         {
-            super(null);
             this.token = token;
-            this.lchild = lchild;
-            this.rchild = rchild;
+            this.hash  = hash;
+            this.size  = size;
         }
 
-        public Hashable lchild()
+        @Override
+        public String toString()
         {
-            return lchild;
+            return "#<RowHash " + token + ' ' + (hash == null ? "null" : Hex.bytesToHex(hash)) + " @ " + size + " bytes>";
         }
+    }
+
+    public void serialize(DataOutputPlus out, int version) throws IOException
+    {
+        out.writeByte(hashdepth);
+        out.writeLong(maxsize);
+        out.writeLong(size);
+        out.writeUTF(partitioner.getClass().getCanonicalName());
+        Token.serializer.serialize(fullRange.left, out, version);
+        Token.serializer.serialize(fullRange.right, out, version);
+        root.serialize(out, version);
+    }
+
+    public long serializedSize(int version)
+    {
+        long size = 1 // mt.hashdepth
+                  + sizeof(maxsize)
+                  + sizeof(this.size)
+                  + sizeof(partitioner.getClass().getCanonicalName());
+        size += Token.serializer.serializedSize(fullRange.left, version);
+        size += Token.serializer.serializedSize(fullRange.right, version);
+        size += root.serializedSize(version);
+        return size;
+    }
+
+    public static MerkleTree deserialize(DataInputPlus in, int version) throws IOException
+    {
+        return deserialize(in, DatabaseDescriptor.getOffheapMerkleTreesEnabled(), version);
+    }
+
+    public static MerkleTree deserialize(DataInputPlus in, boolean offHeapRequested, int version) throws IOException
+    {
+        int hashDepth = in.readByte();
+        long maxSize = in.readLong();
+        int innerNodeCount = Ints.checkedCast(in.readLong());
 
-        public Hashable rchild()
+        IPartitioner partitioner;
+        try
+        {
+            partitioner = FBUtilities.newPartitioner(in.readUTF());
+        }
+        catch (ConfigurationException e)
         {
-            return rchild;
+            throw new IOException(e);
         }
 
-        public void lchild(Hashable child)
+        Token left = Token.serializer.deserialize(in, partitioner, version);
+        Token right = Token.serializer.deserialize(in, partitioner, version);
+        Range<Token> fullRange = new Range<>(left, right);
+        Node root = deserializeTree(in, partitioner, innerNodeCount, offHeapRequested, version);
+        return new MerkleTree(root, partitioner, fullRange, hashDepth, maxSize, innerNodeCount);
+    }
+
+    private static boolean warnedOnce;
+    private static Node deserializeTree(DataInputPlus in, IPartitioner partitioner, int innerNodeCount, boolean offHeapRequested, int version) throws IOException
+    {
+        boolean offHeapSupported = partitioner instanceof Murmur3Partitioner || partitioner instanceof RandomPartitioner;
+        // Off-heap was requested but this partitioner cannot use it: log (guarded by the static warnedOnce flag), then fall back.
+        if (offHeapRequested && !offHeapSupported && !warnedOnce)
         {
-            lchild = child;
+            logger.warn("Configuration requests offheap merkle trees, but partitioner does not support it. Ignoring.");
+            warnedOnce = true;
         }
+        // Deserialize into a direct buffer only when off-heap is both requested and supported; otherwise build on-heap nodes.
-        public void rchild(Hashable child)
+        return offHeapRequested && offHeapSupported
+             ? deserializeOffHeap(in, partitioner, innerNodeCount, version)
+             : OnHeapNode.deserialize(in, partitioner, version);
+    }
+
+    /*
+     * Coordinating multiple trees from multiple replicas can get expensive.
+     * On the deserialization path, we know in advance what the tree looks like,
+     * so we can pre-size an offheap buffer and deserialize into that.
+     */
+
+    MerkleTree tryMoveOffHeap() throws IOException
+    {
+        boolean offHeapEnabled = DatabaseDescriptor.getOffheapMerkleTreesEnabled();
+        boolean offHeapSupported = partitioner instanceof Murmur3Partitioner || partitioner instanceof RandomPartitioner;
+        // Off-heap is configured but this partitioner cannot use it: log (guarded by the static warnedOnce flag) and keep the tree as is.
+        if (offHeapEnabled && !offHeapSupported && !warnedOnce)
         {
-            rchild = child;
+            logger.warn("Configuration requests offheap merkle trees, but partitioner does not support it. Ignoring.");
+            warnedOnce = true;
         }
+        // Only a tree whose root is an OnHeapNode is moved; in every other case this tree is returned unchanged.
-        Hashable calc()
+        return root instanceof OnHeapNode && offHeapEnabled && offHeapSupported ? moveOffHeap() : this;
+    }
+
+    private MerkleTree moveOffHeap() throws IOException
+    {
+        assert root instanceof OnHeapNode;
+        int bufferSize = offHeapBufferSize(Ints.checkedCast(size), partitioner);
+        logger.debug("Allocating direct buffer of size {} to move merkle tree off heap", bufferSize);
+        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize);
+        int pointer = ((OnHeapNode) root).serializeOffHeap(buffer, partitioner);
+        OffHeapNode newRoot = fromPointer(pointer, buffer, partitioner).attachRef();
+        return new MerkleTree(newRoot, partitioner, fullRange, hashdepth, maxsize, size);
+    }
+
+    private static OffHeapNode deserializeOffHeap(DataInputPlus in, IPartitioner partitioner, int innerNodeCount, int version) throws IOException
+    {
+        int bufferSize = offHeapBufferSize(innerNodeCount, partitioner);
+        logger.debug("Allocating direct buffer of size {} for merkle tree deserialization", bufferSize);
+        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize);
+        int pointer = OffHeapNode.deserialize(in, buffer, partitioner, version);
+        return fromPointer(pointer, buffer, partitioner).attachRef();
+    }
+
+    private static OffHeapNode fromPointer(int pointer, ByteBuffer buffer, IPartitioner partitioner)
+    {
+        return pointer >= 0 ? new OffHeapInner(buffer, pointer, partitioner) : new OffHeapLeaf(buffer, ~pointer);
+    }
+
+    private static int offHeapBufferSize(int innerNodeCount, IPartitioner partitioner)
+    {
+        return innerNodeCount * OffHeapInner.maxOffHeapSize(partitioner) + (innerNodeCount + 1) * OffHeapLeaf.maxOffHeapSize();
+    }
+
+    interface Node
+    {
+        byte[] hash();
+
+        boolean hasEmptyHash();
+
+        void hash(byte[] hash);
+
+        boolean hashesDiffer(Node other);
+
+        default Node compute()
         {
-            if (hash == null)
-            {
-                // hash and size haven't been calculated; calc children then compute
-                Hashable lnode = lchild.calc();
-                Hashable rnode = rchild.calc();
-                // cache the computed value
-                hash(lnode.hash, rnode.hash);
-                sizeOfRange = lnode.sizeOfRange + rnode.sizeOfRange;
-                rowsInRange = lnode.rowsInRange + rnode.rowsInRange;
-            }
             return this;
         }
 
-        /**
-         * Recursive toString.
-         */
-        public void toString(StringBuilder buff, int maxdepth)
+        default long sizeOfRange()
         {
-            buff.append("#<").append(getClass().getSimpleName());
-            buff.append(" ").append(token);
-            buff.append(" hash=").append(Hashable.toString(hash()));
-            buff.append(" children=[");
-            if (maxdepth < 1)
-            {
-                buff.append("#");
-            }
-            else
-            {
-                if (lchild == null)
-                    buff.append("null");
-                else
-                    lchild.toString(buff, maxdepth-1);
-                buff.append(" ");
-                if (rchild == null)
-                    buff.append("null");
-                else
-                    rchild.toString(buff, maxdepth-1);
-            }
-            buff.append("]>");
+            return 0;
+        }
+
+        default long partitionsInRange()
+        {
+            return 0;
+        }
+
+        void serialize(DataOutputPlus out, int version) throws IOException;
+        int serializedSize(int version);
+
+        void toString(StringBuilder buff, int maxdepth);
+
+        static String toString(byte[] hash)
+        {
+            return hash == null
+                 ? "null"
+                 : '[' + Hex.bytesToHex(hash) + ']';
+        }
+
+        boolean equals(Node node);
+    }
+
+    static abstract class OnHeapNode implements Node
+    {
+        long sizeOfRange;
+        long partitionsInRange;
+
+        protected byte[] hash;
+
+        OnHeapNode(byte[] hash)
+        {
+            if (hash == null)
+                throw new IllegalArgumentException();
+
+            this.hash = hash;
+        }
+
+        public byte[] hash()
+        {
+            return hash;
+        }
+
+        public boolean hasEmptyHash()
+        {
+            //noinspection ArrayEquality
+            return hash == EMPTY_HASH;
+        }
+
+        public void hash(byte[] hash)
+        {
+            if (hash == null)
+                throw new IllegalArgumentException();
+
+            this.hash = hash;
+        }
+
+        public boolean hashesDiffer(Node other)
+        {
+            return other instanceof OnHeapNode
+                 ? hashesDiffer( (OnHeapNode) other)
+                 : hashesDiffer((OffHeapNode) other);
+        }
+
+        private boolean hashesDiffer(OnHeapNode other)
+        {
+            return !Arrays.equals(hash(), other.hash());
+        }
+
+        private boolean hashesDiffer(OffHeapNode other)
+        {
+            return compare(hash(), other.buffer(), other.hashBytesOffset(), HASH_SIZE) != 0;
         }
 
         @Override
-        public String toString()
+        public long sizeOfRange()
         {
-            StringBuilder buff = new StringBuilder();
-            toString(buff, 1);
-            return buff.toString();
+            return sizeOfRange;
         }
 
-        private static class InnerSerializer implements IPartitionerDependentSerializer<Inner>
+        @Override
+        public long partitionsInRange()
         {
-            public void serialize(Inner inner, DataOutputPlus out, int version) throws IOException
-            {
-                Token.serializer.serialize(inner.token, out, version);
-                Hashable.serializer.serialize(inner.lchild, out, version);
-                Hashable.serializer.serialize(inner.rchild, out, version);
-            }
+            return partitionsInRange;
+        }
 
-            public Inner deserialize(DataInput in, IPartitioner p, int version) throws IOException
+        static OnHeapNode deserialize(DataInputPlus in, IPartitioner p, int version) throws IOException
+        {
+            byte ident = in.readByte();
+
+            switch (ident)
             {
-                Token token = Token.serializer.deserialize(in, p, version);
-                Hashable lchild = Hashable.serializer.deserialize(in, p, version);
-                Hashable rchild = Hashable.serializer.deserialize(in, p, version);
-                return new Inner(token, lchild, rchild);
+                case Inner.IDENT:
+                    return OnHeapInner.deserializeWithoutIdent(in, p, version);
+                case Leaf.IDENT:
+                    return OnHeapLeaf.deserializeWithoutIdent(in);
+                default:
+                    throw new IOException("Unexpected node type: " + ident);
             }
+        }
+
+        abstract int serializeOffHeap(ByteBuffer buffer, IPartitioner p) throws IOException;
+    }
+
+    static abstract class OffHeapNode implements Node
+    {
+        protected final ByteBuffer buffer;
+        protected final int offset;
+
+        OffHeapNode(ByteBuffer buffer, int offset)
+        {
+            this.buffer = buffer;
+            this.offset = offset;
+        }
+
+        ByteBuffer buffer()
+        {
+            return buffer;
+        }
+
+        public byte[] hash()
+        {
+            final int position = buffer.position();
+            buffer.position(hashBytesOffset());
+            byte[] array = new byte[HASH_SIZE];
+            buffer.get(array);
+            buffer.position(position);
+            return array;
+        }
+
+        public boolean hasEmptyHash()
+        {
+            return compare(buffer(), hashBytesOffset(), HASH_SIZE, EMPTY_HASH) == 0;
+        }
+
+        public void hash(byte[] hash)
+        {
+            throw new UnsupportedOperationException();
+        }
 
-            public long serializedSize(Inner inner, int version)
+        public boolean hashesDiffer(Node other)
+        {
+            return other instanceof OnHeapNode
+                 ? hashesDiffer((OnHeapNode) other)
+                 : hashesDiffer((OffHeapNode) other);
+        }
+
+        private boolean hashesDiffer(OnHeapNode other)
+        {
+            return compare(buffer(), hashBytesOffset(), HASH_SIZE, other.hash()) != 0;
+        }
+
+        private boolean hashesDiffer(OffHeapNode other)
+        {
+            int thisOffset = hashBytesOffset();
+            int otherOffset = other.hashBytesOffset();
+
+            for (int i = 0; i < HASH_SIZE; i += 8)
+                if (buffer().getLong(thisOffset + i) != other.buffer().getLong(otherOffset + i))
+                    return true;
+
+            return false;
+        }
+
+        OffHeapNode attachRef()
+        {
+            if (Ref.DEBUG_ENABLED)
+                MemoryUtil.setAttachment(buffer, new Ref<>(this, null));
+            return this;
+        }
+
+        void release()
+        {
+            Object attachment = MemoryUtil.getAttachment(buffer);
+            if (attachment instanceof Ref)
+                ((Ref) attachment).release();
+            FileUtils.clean(buffer);
+        }
+
+        abstract int hashBytesOffset();
+
+        static int deserialize(DataInputPlus in, ByteBuffer buffer, IPartitioner p, int version) throws IOException
+        {
+            byte ident = in.readByte();
+
+            switch (ident)
             {
-                return Token.serializer.serializedSize(inner.token, version)
-                     + Hashable.serializer.serializedSize(inner.lchild, version)
-                     + Hashable.serializer.serializedSize(inner.rchild, version);
+                case Inner.IDENT:
+                    return OffHeapInner.deserializeWithoutIdent(in, buffer, p, version);
+                case Leaf.IDENT:
+                    return  OffHeapLeaf.deserializeWithoutIdent(in, buffer);
+                default:
+                    throw new IOException("Unexpected node type: " + ident);
             }
         }
     }
@@ -922,237 +1101,440 @@ public class MerkleTree implements Serializable
      * tree extending below the Leaf is generated in memory, but only the root
      * is stored in the Leaf.
      */
-    static class Leaf extends Hashable
+    interface Leaf extends Node
     {
-        public static final long serialVersionUID = 1L;
         static final byte IDENT = 1;
-        private static final LeafSerializer serializer = new LeafSerializer();
+
+        default void serialize(DataOutputPlus out, int version) throws IOException
+        {
+            byte[] hash = hash();
+            assert hash.length == HASH_SIZE;
+
+            out.writeByte(Leaf.IDENT);
+
+            if (!hasEmptyHash())
+            {
+                out.writeByte(HASH_SIZE);
+                out.write(hash);
+            }
+            else
+            {
+                out.writeByte(0);
+            }
+        }
+
+        default int serializedSize(int version)
+        {
+            return 2 + (hasEmptyHash() ? 0 : HASH_SIZE);
+        }
+
+        default void toString(StringBuilder buff, int maxdepth)
+        {
+            buff.append(toString());
+        }
+
+        default boolean equals(Node other)
+        {
+            return other instanceof Leaf && !hashesDiffer(other);
+        }
+    }
+
+    static class OnHeapLeaf extends OnHeapNode implements Leaf
+    {
+        OnHeapLeaf()
+        {
+            super(EMPTY_HASH);
+        }
+
+        OnHeapLeaf(byte[] hash)
+        {
+            super(hash);
+        }
 
         /**
-         * Constructs a null hash.
+         * Mixes the given value into our hash. If our hash is empty,
+         * our hash will become the given value.
          */
-        public Leaf()
+        void addHash(byte[] partitionHash, long partitionSize)
         {
-            super(null);
+            if (hasEmptyHash())
+                hash(partitionHash);
+            else
+                FBUtilities.xorOntoLeft(hash, partitionHash);
+
+            sizeOfRange += partitionSize;
+            partitionsInRange += 1;
         }
 
-        public Leaf(byte[] hash)
+        static OnHeapLeaf deserializeWithoutIdent(DataInputPlus in) throws IOException
         {
-            super(hash);
+            if (in.readByte() > 0)
+            {
+                byte[] hash = new byte[HASH_SIZE];
+                in.readFully(hash);
+                return new OnHeapLeaf(hash);
+            }
+            else
+            {
+                return new OnHeapLeaf();
+            }
         }
 
-        public void toString(StringBuilder buff, int maxdepth)
+        int serializeOffHeap(ByteBuffer buffer, IPartitioner p)
         {
-            buff.append(toString());
+            if (buffer.remaining() < OffHeapLeaf.maxOffHeapSize())
+                throw new IllegalStateException("Insufficient remaining bytes to deserialize a Leaf node off-heap");
+
+            if (hash.length != HASH_SIZE)
+                throw new IllegalArgumentException("Hash of unexpected size when serializing a Leaf off-heap: " + hash.length);
+
+            final int position = buffer.position();
+            buffer.put(hash);
+            return ~position;
         }
 
         @Override
         public String toString()
         {
-            return "#<Leaf " + Hashable.toString(hash()) + ">";
+            return "#<OnHeapLeaf " + Node.toString(hash()) + '>';
         }
+    }
+
+    static class OffHeapLeaf extends OffHeapNode implements Leaf
+    {
+        static final int HASH_BYTES_OFFSET = 0;
 
-        private static class LeafSerializer implements IPartitionerDependentSerializer<Leaf>
+        OffHeapLeaf(ByteBuffer buffer, int offset)
         {
-            public void serialize(Leaf leaf, DataOutputPlus out, int version) throws IOException
-            {
-                if (leaf.hash == null)
-                {
-                    out.writeByte(-1);
-                }
-                else
-                {
-                    out.writeByte(leaf.hash.length);
-                    out.write(leaf.hash);
-                }
-            }
+            super(buffer, offset);
+        }
+
+        public int hashBytesOffset()
+        {
+            return offset + HASH_BYTES_OFFSET;
+        }
+
+        static int deserializeWithoutIdent(DataInput in, ByteBuffer buffer) throws IOException
+        {
+            if (buffer.remaining() < maxOffHeapSize())
+                throw new IllegalStateException("Insufficient remaining bytes to deserialize a Leaf node off-heap");
 
-            public Leaf deserialize(DataInput in, IPartitioner p, int version) throws IOException
+            final int position = buffer.position();
+
+            int hashLength = in.readByte();
+            if (hashLength > 0)
             {
-                int hashLen = in.readByte();
-                byte[] hash = hashLen < 0 ? null : new byte[hashLen];
-                if (hash != null)
-                    in.readFully(hash);
-                return new Leaf(hash);
-            }
+                if (hashLength != HASH_SIZE)
+                    throw new IllegalStateException("Hash of unexpected size when deserializing an off-heap Leaf node: " + hashLength);
 
-            public long serializedSize(Leaf leaf, int version)
+                byte[] hashBytes = getTempArray(HASH_SIZE);
+                in.readFully(hashBytes, 0, HASH_SIZE);
+                buffer.put(hashBytes, 0, HASH_SIZE);
+            }
+            else
             {
-                long size = 1;
-                if (leaf.hash != null)
-                    size += leaf.hash().length;
-                return size;
+                buffer.put(EMPTY_HASH, 0, HASH_SIZE);
             }
+
+            return ~position;
         }
-    }
 
-    /**
-     * Hash value representing a row, to be used to pass hashes to the MerkleTree.
-     * The byte[] hash value should contain a digest of the key and value of the row
-     * created using a very strong hash function.
-     */
-    public static class RowHash
-    {
-        public final Token token;
-        public final byte[] hash;
-        public final long size;
-        public RowHash(Token token, byte[] hash, long size)
+        static int maxOffHeapSize()
         {
-            this.token = token;
-            this.hash  = hash;
-            this.size = size;
+            return HASH_SIZE;
         }
 
         @Override
         public String toString()
         {
-            return "#<RowHash " + token + " " + Hashable.toString(hash) + " @ " + size + " bytes>";
+            return "#<OffHeapLeaf " + Node.toString(hash()) + '>';
         }
     }
 
     /**
-     * Abstract class containing hashing logic, and containing a single hash field.
+     * An inner node in the MerkleTree. Inners can contain cached hash values, which
+     * are the binary hash of their two children.
      */
-    static abstract class Hashable implements Serializable
+    interface Inner extends Node
     {
-        private static final long serialVersionUID = 1L;
-        private static final IPartitionerDependentSerializer<Hashable> serializer = new HashableSerializer();
+        static final byte IDENT = 2;
 
-        protected byte[] hash;
-        protected long sizeOfRange;
-        protected long rowsInRange;
+        public Token token();
 
-        protected Hashable(byte[] hash)
+        public Node left();
+        public Node right();
+
+        default void serialize(DataOutputPlus out, int version) throws IOException
         {
-            this.hash = hash;
+            out.writeByte(Inner.IDENT);
+            Token.serializer.serialize(token(), out, version);
+            left().serialize(out, version);
+            right().serialize(out, version);
         }
 
-        public byte[] hash()
+        default int serializedSize(int version)
         {
-            return hash;
+            return 1
+                 + (int) Token.serializer.serializedSize(token(), version)
+                 + left().serializedSize(version)
+                 + right().serializedSize(version);
         }
 
-        public long sizeOfRange()
+        default void toString(StringBuilder buff, int maxdepth)
         {
-            return sizeOfRange;
+            buff.append("#<").append(getClass().getSimpleName())
+                .append(' ').append(token())
+                .append(" hash=").append(Node.toString(hash()))
+                .append(" children=[");
+
+            if (maxdepth < 1)
+            {
+                buff.append('#');
+            }
+            else
+            {
+                Node left = left();
+                if (left == null)
+                    buff.append("null");
+                else
+                    left.toString(buff, maxdepth - 1);
+
+                buff.append(' ');
+
+                Node right = right();
+                if (right == null)
+                    buff.append("null");
+                else
+                    right.toString(buff, maxdepth - 1);
+            }
+
+            buff.append("]>");
         }
 
-        public long rowsInRange()
+        default boolean equals(Node other)
         {
-            return rowsInRange;
+            if (!(other instanceof Inner))
+                return false;
+            Inner that = (Inner) other;
+            return !hashesDiffer(other) && this.left().equals(that.left()) && this.right().equals(that.right());
         }
+    }
+
+    static class OnHeapInner extends OnHeapNode implements Inner
+    {
+        private final Token token;
 
-        void hash(byte[] hash)
+        private OnHeapNode left;
+        private OnHeapNode right;
+
+        private boolean computed;
+
+        OnHeapInner(Token token, OnHeapNode left, OnHeapNode right)
         {
-            this.hash = hash;
+            super(EMPTY_HASH);
+
+            this.token = token;
+            this.left = left;
+            this.right = right;
         }
 
-        Hashable calc()
+        public Token token()
         {
-            return this;
+            return token;
         }
 
-        /**
-         * Sets the value of this hash to binaryHash of its children.
-         * @param lefthash Hash of left child.
-         * @param righthash Hash of right child.
-         */
-        void hash(byte[] lefthash, byte[] righthash)
+        public OnHeapNode left()
         {
-            hash = binaryHash(lefthash, righthash);
+            return left;
         }
 
-        /**
-         * Mixes the given value into our hash. If our hash is null,
-         * our hash will become the given value.
-         */
-        void addHash(byte[] righthash, long sizeOfRow)
+        public OnHeapNode right()
         {
-            if (hash == null)
-                hash = righthash;
-            else
-                hash = binaryHash(hash, righthash);
-            this.sizeOfRange += sizeOfRow;
-            this.rowsInRange += 1;
+            return right;
         }
 
-        /**
-         * The primitive with which all hashing should be accomplished: hashes
-         * a left and right value together.
-         */
-        static byte[] binaryHash(final byte[] left, final byte[] right)
+        void left(OnHeapNode child)
         {
-            return FBUtilities.xor(left, right);
+            left = child;
         }
 
-        public abstract void toString(StringBuilder buff, int maxdepth);
-
-        public static String toString(byte[] hash)
+        void right(OnHeapNode child)
         {
-            if (hash == null)
-                return "null";
-            return "[" + Hex.bytesToHex(hash) + "]";
+            right = child;
         }
 
-        private static class HashableSerializer implements IPartitionerDependentSerializer<Hashable>
+        @Override
+        public Node compute()
         {
-            public void serialize(Hashable h, DataOutputPlus out, int version) throws IOException
+            if (!computed) // hash and size haven't been calculated; compute children then compute this
             {
-                if (h instanceof Inner)
-                {
-                    out.writeByte(Inner.IDENT);
-                    Inner.serializer.serialize((Inner)h, out, version);
-                }
-                else if (h instanceof Leaf)
-                {
-                    out.writeByte(Leaf.IDENT);
-                    Leaf.serializer.serialize((Leaf) h, out, version);
-                }
-                else
-                    throw new IOException("Unexpected Hashable: " + h.getClass().getCanonicalName());
-            }
+                left.compute();
+                right.compute();
 
-            public Hashable deserialize(DataInput in, IPartitioner p, int version) throws IOException
-            {
-                byte ident = in.readByte();
-                if (Inner.IDENT == ident)
-                    return Inner.serializer.deserialize(in, p, version);
-                else if (Leaf.IDENT == ident)
-                    return Leaf.serializer.deserialize(in, p, version);
-                else
-                    throw new IOException("Unexpected Hashable: " + ident);
+                if (!left.hasEmptyHash() && !right.hasEmptyHash())
+                    hash = FBUtilities.xor(left.hash(), right.hash());
+                else if (left.hasEmptyHash())
+                    hash = right.hash();
+                else if (right.hasEmptyHash())
+                    hash = left.hash();
+
+                sizeOfRange       = left.sizeOfRange()       + right.sizeOfRange();
+                partitionsInRange = left.partitionsInRange() + right.partitionsInRange();
+
+                computed = true;
             }
 
-            public long serializedSize(Hashable h, int version)
+            return this;
+        }
+
+        static OnHeapInner deserializeWithoutIdent(DataInputPlus in, IPartitioner p, int version) throws IOException
+        {
+            Token token = Token.serializer.deserialize(in, p, version);
+            OnHeapNode  left = OnHeapNode.deserialize(in, p, version);
+            OnHeapNode right = OnHeapNode.deserialize(in, p, version);
+            return new OnHeapInner(token, left, right);
+        }
+
+        int serializeOffHeap(ByteBuffer buffer, IPartitioner partitioner) throws IOException
+        {
+            if (buffer.remaining() < OffHeapInner.maxOffHeapSize(partitioner))
+                throw new IllegalStateException("Insufficient remaining bytes to deserialize Inner node off-heap");
+
+            final int offset = buffer.position();
+
+            int tokenSize = partitioner.getTokenFactory().byteSize(token);
+            buffer.putShort(offset + OffHeapInner.TOKEN_LENGTH_OFFSET, Shorts.checkedCast(tokenSize));
+            buffer.position(offset + OffHeapInner.TOKEN_BYTES_OFFSET);
+            partitioner.getTokenFactory().serialize(token, buffer);
+
+            int  leftPointer =  left.serializeOffHeap(buffer, partitioner);
+            int rightPointer = right.serializeOffHeap(buffer, partitioner);
+
+            buffer.putInt(offset + OffHeapInner.LEFT_CHILD_POINTER_OFFSET,  leftPointer);
+            buffer.putInt(offset + OffHeapInner.RIGHT_CHILD_POINTER_OFFSET, rightPointer);
+
+            int  leftHashOffset = OffHeapInner.hashBytesOffset(leftPointer);
+            int rightHashOffset = OffHeapInner.hashBytesOffset(rightPointer);
+
+            for (int i = 0; i < HASH_SIZE; i += 8)
             {
-                if (h instanceof Inner)
-                    return 1 + Inner.serializer.serializedSize((Inner) h, version);
-                else if (h instanceof Leaf)
-                    return 1 + Leaf.serializer.serializedSize((Leaf) h, version);
-                throw new AssertionError(h.getClass());
+                buffer.putLong(offset + OffHeapInner.HASH_BYTES_OFFSET + i,
+                               buffer.getLong(leftHashOffset  + i) ^ buffer.getLong(rightHashOffset + i));
             }
+
+            return offset;
+        }
+
+        @Override
+        public String toString()
+        {
+            StringBuilder buff = new StringBuilder();
+            toString(buff, 1);
+            return buff.toString();
         }
     }
 
-    /**
-     * Exceptions that stop recursion early when we are sure that no answer
-     * can be found.
-     */
-    static abstract class StopRecursion extends Exception
+    static class OffHeapInner extends OffHeapNode implements Inner
     {
-        static class BadRange extends StopRecursion
+        /**
+         * All we want to keep here is just a pointer to the start of the Inner node in the
+         * direct buffer. From there, we'll be able to deserialize the following, in this order:
+         *
+         * 1. pointer to left child (int)
+         * 2. pointer to right child (int)
+         * 3. hash bytes (space allocated as HASH_SIZE)
+         * 4. token length (short)
+         * 5. token bytes (variable length)
+         */
+        static final int LEFT_CHILD_POINTER_OFFSET  = 0;
+        static final int RIGHT_CHILD_POINTER_OFFSET = 4;
+        static final int HASH_BYTES_OFFSET          = 8;
+        static final int TOKEN_LENGTH_OFFSET        = 8 + HASH_SIZE;
+        static final int TOKEN_BYTES_OFFSET         = TOKEN_LENGTH_OFFSET + 2;
+
+        private final IPartitioner partitioner;
+
+        OffHeapInner(ByteBuffer buffer, int offset, IPartitioner partitioner)
         {
-            public BadRange(){ super(); }
+            super(buffer, offset);
+            this.partitioner = partitioner;
         }
 
-        static class InvalidHash extends StopRecursion
+        public Token token()
         {
-            public InvalidHash(){ super(); }
+            int length = buffer.getShort(offset + TOKEN_LENGTH_OFFSET);
+            return partitioner.getTokenFactory().fromByteBuffer(buffer, offset + TOKEN_BYTES_OFFSET, length);
         }
 
-        static class TooDeep extends StopRecursion
+        public Node left()
         {
-            public TooDeep(){ super(); }
+            int pointer = buffer.getInt(offset + LEFT_CHILD_POINTER_OFFSET);
+            return pointer >= 0 ? new OffHeapInner(buffer, pointer, partitioner) : new OffHeapLeaf(buffer, ~pointer);
+        }
+
+        public Node right()
+        {
+            int pointer = buffer.getInt(offset + RIGHT_CHILD_POINTER_OFFSET);
+            return pointer >= 0 ? new OffHeapInner(buffer, pointer, partitioner) : new OffHeapLeaf(buffer, ~pointer);
+        }
+
+        public int hashBytesOffset()
+        {
+            return offset + HASH_BYTES_OFFSET;
+        }
+
+        static int deserializeWithoutIdent(DataInputPlus in, ByteBuffer buffer, IPartitioner partitioner, int version) throws IOException
+        {
+            if (buffer.remaining() < maxOffHeapSize(partitioner))
+                throw new IllegalStateException("Insufficient remaining bytes to deserialize Inner node off-heap");
+
+            final int offset = buffer.position();
+
+            int tokenSize = Token.serializer.deserializeSize(in);
+            byte[] tokenBytes = getTempArray(tokenSize);
+            in.readFully(tokenBytes, 0, tokenSize);
+
+            buffer.putShort(offset + OffHeapInner.TOKEN_LENGTH_OFFSET, Shorts.checkedCast(tokenSize));
+            buffer.position(offset + OffHeapInner.TOKEN_BYTES_OFFSET);
+            buffer.put(tokenBytes, 0, tokenSize);
+
+            int leftPointer  = OffHeapNode.deserialize(in, buffer, partitioner, version);
+            int rightPointer = OffHeapNode.deserialize(in, buffer, partitioner, version);
+
+            buffer.putInt(offset + OffHeapInner.LEFT_CHILD_POINTER_OFFSET,  leftPointer);
+            buffer.putInt(offset + OffHeapInner.RIGHT_CHILD_POINTER_OFFSET, rightPointer);
+
+            int leftHashOffset  = hashBytesOffset(leftPointer);
+            int rightHashOffset = hashBytesOffset(rightPointer);
+
+            for (int i = 0; i < HASH_SIZE; i += 8)
+            {
+                buffer.putLong(offset + OffHeapInner.HASH_BYTES_OFFSET + i,
+                               buffer.getLong(leftHashOffset  + i) ^ buffer.getLong(rightHashOffset + i));
+            }
+
+            return offset;
+        }
+
+        static int maxOffHeapSize(IPartitioner partitioner)
+        {
+            return 4 // left pointer
+                 + 4 // right pointer
+                 + HASH_SIZE
+                 + 2 + partitioner.getMaxTokenSize();
+        }
+
+        static int hashBytesOffset(int pointer)
+        {
+            return pointer >= 0 ? pointer + OffHeapInner.HASH_BYTES_OFFSET : ~pointer + OffHeapLeaf.HASH_BYTES_OFFSET;
+        }
+
+        @Override
+        public String toString()
+        {
+            StringBuilder buff = new StringBuilder();
+            toString(buff, 1);
+            return buff.toString();
         }
     }
 
@@ -1183,10 +1565,10 @@ public class MerkleTree implements Serializable
     {
         byte[] hashLeft = new byte[bytesPerHash];
         byte[] hashRigth = new byte[bytesPerHash];
-        Leaf left = new Leaf(hashLeft);
-        Leaf right = new Leaf(hashRigth);
-        Inner inner = new Inner(partitioner.getMinimumToken(), left, right);
-        inner.calc();
+        OnHeapLeaf left = new OnHeapLeaf(hashLeft);
+        OnHeapLeaf right = new OnHeapLeaf(hashRigth);
+        Inner inner = new OnHeapInner(partitioner.getMinimumToken(), left, right);
+        inner.compute();
 
         // Some partioners have variable token sizes, try to estimate as close as we can by using the same
         // heap estimate as the memtables use.
diff --git a/src/java/org/apache/cassandra/utils/MerkleTrees.java b/src/java/org/apache/cassandra/utils/MerkleTrees.java
index d2a8058..9cf04c5 100644
--- a/src/java/org/apache/cassandra/utils/MerkleTrees.java
+++ b/src/java/org/apache/cassandra/utils/MerkleTrees.java
@@ -44,9 +44,9 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
 {
     public static final MerkleTreesSerializer serializer = new MerkleTreesSerializer();
 
-    private Map<Range<Token>, MerkleTree> merkleTrees = new TreeMap<>(new TokenRangeComparator());
+    private final Map<Range<Token>, MerkleTree> merkleTrees = new TreeMap<>(new TokenRangeComparator());
 
-    private IPartitioner partitioner;
+    private final IPartitioner partitioner;
 
     /**
      * Creates empty MerkleTrees object.
@@ -143,6 +143,14 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     }
 
     /**
+     * Dereference all merkle trees and release direct memory for all off-heap trees.
+     */
+    public void release()
+    {
+        merkleTrees.values().forEach(MerkleTree::release); merkleTrees.clear();
+    }
+
+    /**
      * Init a selected MerkleTree with an even tree distribution.
      *
      * @param range
@@ -247,11 +255,11 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     }
 
     /**
-     * Get an iterator for all the invalids generated by the MerkleTrees.
+     * Get an iterator over all the tree ranges of the MerkleTrees.
      *
      * @return
      */
-    public TreeRangeIterator invalids()
+    public TreeRangeIterator rangeIterator()
     {
         return new TreeRangeIterator();
     }
@@ -285,30 +293,20 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     @VisibleForTesting
     public byte[] hash(Range<Token> range)
     {
-        ByteArrayOutputStream baos = new ByteArrayOutputStream();
-        boolean hashed = false;
-
-        try
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream())
         {
-            for (Range<Token> rt : merkleTrees.keySet())
-            {
-                if (rt.intersects(range))
-                {
-                    byte[] bytes = merkleTrees.get(rt).hash(range);
-                    if (bytes != null)
-                    {
-                        baos.write(bytes);
-                        hashed = true;
-                    }
-                }
-            }
+            boolean hashed = false;
+
+            for (Map.Entry<Range<Token>, MerkleTree> entry : merkleTrees.entrySet())
+                if (entry.getKey().intersects(range))
+                    hashed |= entry.getValue().ifHashesRange(range, n -> baos.write(n.hash()));
+
+            return hashed ? baos.toByteArray() : null;
         }
         catch (IOException e)
         {
             throw new RuntimeException("Unable to append merkle tree hash to result");
         }
-
-        return hashed ? baos.toByteArray() : null;
     }
 
     /**
@@ -354,7 +352,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
         {
             if (it.hasNext())
             {
-                current = it.next().invalids();
+                current = it.next().rangeIterator();
 
                 return current.next();
             }
@@ -369,6 +367,17 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     }
 
     /**
+     * @return a new {@link MerkleTrees} instance with all trees moved off heap.
+     */
+    public MerkleTrees tryMoveOffHeap() throws IOException
+    {
+        Map<Range<Token>, MerkleTree> movedTrees = new TreeMap<>(new TokenRangeComparator());
+        for (Map.Entry<Range<Token>, MerkleTree> entry : merkleTrees.entrySet())
+            movedTrees.put(entry.getKey(), entry.getValue().tryMoveOffHeap());
+        return new MerkleTrees(partitioner, movedTrees.values());
+    }
+
+    /**
      * Get the differences between the two sets of MerkleTrees.
      *
      * @param ltree
@@ -379,9 +388,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     {
         List<Range<Token>> differences = new ArrayList<>();
         for (MerkleTree tree : ltree.merkleTrees.values())
-        {
             differences.addAll(MerkleTree.difference(tree, rtree.getMerkleTree(tree.fullRange)));
-        }
         return differences;
     }
 
@@ -392,7 +399,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
             out.writeInt(trees.merkleTrees.size());
             for (MerkleTree tree : trees.merkleTrees.values())
             {
-                MerkleTree.serializer.serialize(tree, out, version);
+                tree.serialize(out, version);
             }
         }
 
@@ -405,7 +412,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
             {
                 for (int i = 0; i < nTrees; i++)
                 {
-                    MerkleTree tree = MerkleTree.serializer.deserialize(in, version);
+                    MerkleTree tree = MerkleTree.deserialize(in, version);
                     trees.add(tree);
 
                     if (partitioner == null)
@@ -425,7 +432,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
             long size = TypeSizes.sizeof(trees.merkleTrees.size());
             for (MerkleTree tree : trees.merkleTrees.values())
             {
-                size += MerkleTree.serializer.serializedSize(tree, version);
+                size += tree.serializedSize(version);
             }
             return size;
         }
diff --git a/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java b/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java
index e787595..443d59e 100644
--- a/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java
+++ b/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java
@@ -238,10 +238,6 @@ public class LocalSyncTaskTest extends AbstractRepairTest
         MerkleTrees tree = new MerkleTrees(partitioner);
         tree.addMerkleTrees((int) Math.pow(2, 15), desc.ranges);
         tree.init();
-        for (MerkleTree.TreeRange r : tree.invalids())
-        {
-            r.ensureHashInitialised();
-        }
         return tree;
     }
 
diff --git a/test/unit/org/apache/cassandra/repair/RepairJobTest.java b/test/unit/org/apache/cassandra/repair/RepairJobTest.java
index 78fa588..b84adaa 100644
--- a/test/unit/org/apache/cassandra/repair/RepairJobTest.java
+++ b/test/unit/org/apache/cassandra/repair/RepairJobTest.java
@@ -774,10 +774,6 @@ public class RepairJobTest
         MerkleTrees tree = new MerkleTrees(MURMUR3_PARTITIONER);
         tree.addMerkleTrees((int) Math.pow(2, 15), fullRange);
         tree.init();
-        for (MerkleTree.TreeRange r : tree.invalids())
-        {
-            r.ensureHashInitialised();
-        }
 
         if (invalidate)
         {
diff --git a/test/unit/org/apache/cassandra/repair/ValidatorTest.java b/test/unit/org/apache/cassandra/repair/ValidatorTest.java
index aec2612..9e848a9 100644
--- a/test/unit/org/apache/cassandra/repair/ValidatorTest.java
+++ b/test/unit/org/apache/cassandra/repair/ValidatorTest.java
@@ -203,12 +203,14 @@ public class ValidatorTest
                                                cfs.getTableName(), Collections.singletonList(new Range<>(sstable.first.getToken(),
                                                                                                                 sstable.last.getToken())));
 
-        ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(),
+        InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2");
+
+        ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host,
                                                                  Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE,
                                                                  false, PreviewKind.NONE);
 
         final CompletableFuture<Message> outgoingMessageSink = registerOutgoingMessageSink();
-        Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE);
+        Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE);
         ValidationManager.instance.submitValidation(cfs, validator);
 
         Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS);
@@ -260,12 +262,14 @@ public class ValidatorTest
                                                      cfs.getTableName(), Collections.singletonList(new Range<>(sstable.first.getToken(),
                                                                                                                sstable.last.getToken())));
 
-        ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(),
+        InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2");
+
+        ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host,
                                                                  Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE,
                                                                  false, PreviewKind.NONE);
 
         final CompletableFuture<Message> outgoingMessageSink = registerOutgoingMessageSink();
-        Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE);
+        Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE);
         ValidationManager.instance.submitValidation(cfs, validator);
 
         Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS);
@@ -320,12 +324,14 @@ public class ValidatorTest
         final RepairJobDesc desc = new RepairJobDesc(repairSessionId, UUIDGen.getTimeUUID(), cfs.keyspace.getName(),
                                                      cfs.getTableName(), ranges);
 
-        ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(),
+        InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2");
+
+        ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host,
                                                                  Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE,
                                                                  false, PreviewKind.NONE);
 
         final CompletableFuture<Message> outgoingMessageSink = registerOutgoingMessageSink();
-        Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE);
+        Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE);
         ValidationManager.instance.submitValidation(cfs, validator);
 
         Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS);
diff --git a/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java b/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java
index 9693010..de69cd7 100644
--- a/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java
+++ b/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java
@@ -74,9 +74,9 @@ public class DifferenceHolderTest
         mt1.init();
         mt2.init();
         // add dummy hashes to both trees
-        for (MerkleTree.TreeRange range : mt1.invalids())
+        for (MerkleTree.TreeRange range : mt1.rangeIterator())
             range.addAll(new MerkleTreesTest.HIterator(range.right));
-        for (MerkleTree.TreeRange range : mt2.invalids())
+        for (MerkleTree.TreeRange range : mt2.rangeIterator())
             range.addAll(new MerkleTreesTest.HIterator(range.right));
 
         MerkleTree.TreeRange leftmost = null;
@@ -85,7 +85,7 @@ public class DifferenceHolderTest
         mt1.maxsize(fullRange, maxsize + 2); // give some room for splitting
 
         // split the leftmost
-        Iterator<MerkleTree.TreeRange> ranges = mt1.invalids();
+        Iterator<MerkleTree.TreeRange> ranges = mt1.rangeIterator();
         leftmost = ranges.next();
         mt1.split(leftmost.right);
 
diff --git a/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java b/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java
index c213271..d8491d0 100644
--- a/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java
+++ b/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java
@@ -1,21 +1,20 @@
 /*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyten ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*    http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package org.apache.cassandra.utils;
 
 import java.math.BigInteger;
@@ -36,10 +35,8 @@ import org.apache.cassandra.dht.RandomPartitioner.BigIntegerToken;
 import org.apache.cassandra.dht.Range;
 import org.apache.cassandra.dht.Token;
 import org.apache.cassandra.io.util.DataInputBuffer;
-import org.apache.cassandra.io.util.DataInputPlus;
 import org.apache.cassandra.io.util.DataOutputBuffer;
 import org.apache.cassandra.net.MessagingService;
-import org.apache.cassandra.utils.MerkleTree.Hashable;
 import org.apache.cassandra.utils.MerkleTree.RowHash;
 import org.apache.cassandra.utils.MerkleTree.TreeRange;
 import org.apache.cassandra.utils.MerkleTree.TreeRangeIterator;
@@ -49,7 +46,7 @@ import static org.junit.Assert.*;
 
 public class MerkleTreeTest
 {
-    public static byte[] DUMMY = "blah".getBytes();
+    private static final byte[] DUMMY = HashingUtils.newMessageDigest("SHA-256").digest("dummy".getBytes());
 
     /**
      * If a test assumes that the tree is 8 units wide, then it should set this value
@@ -68,6 +65,9 @@ public class MerkleTreeTest
     @Before
     public void setup()
     {
+        DatabaseDescriptor.clientInitialization();
+        DatabaseDescriptor.setOffheapMerkleTreesEnabled(false);
+
         TOKEN_SCALE = new BigInteger("8");
         partitioner = RandomPartitioner.instance;
         // TODO need to trickle TokenSerializer
@@ -171,7 +171,7 @@ public class MerkleTreeTest
         Iterator<TreeRange> ranges;
 
         // (zero, zero]
-        ranges = mt.invalids();
+        ranges = mt.rangeIterator();
         assertEquals(new Range<>(tok(-1), tok(-1)), ranges.next());
         assertFalse(ranges.hasNext());
 
@@ -181,7 +181,7 @@ public class MerkleTreeTest
         mt.split(tok(6));
         mt.split(tok(3));
         mt.split(tok(5));
-        ranges = mt.invalids();
+        ranges = mt.rangeIterator();
         assertEquals(new Range<>(tok(6), tok(-1)), ranges.next());
         assertEquals(new Range<>(tok(-1), tok(2)), ranges.next());
         assertEquals(new Range<>(tok(2), tok(3)), ranges.next());
@@ -200,7 +200,7 @@ public class MerkleTreeTest
         Range<Token> range = new Range<>(tok(-1), tok(-1));
 
         // (zero, zero]
-        assertNull(mt.hash(range));
+        assertFalse(mt.hashesRange(range));
 
         // validate the range
         mt.get(tok(-1)).hash(val);
@@ -223,11 +223,12 @@ public class MerkleTreeTest
         // (zero,two] (two,four] (four, zero]
         mt.split(tok(4));
         mt.split(tok(2));
-        assertNull(mt.hash(left));
-        assertNull(mt.hash(partial));
-        assertNull(mt.hash(right));
-        assertNull(mt.hash(linvalid));
-        assertNull(mt.hash(rinvalid));
+
+        assertFalse(mt.hashesRange(left));
+        assertFalse(mt.hashesRange(partial));
+        assertFalse(mt.hashesRange(right));
+        assertFalse(mt.hashesRange(linvalid));
+        assertFalse(mt.hashesRange(rinvalid));
 
         // validate the range
         mt.get(tok(2)).hash(val);
@@ -237,8 +238,8 @@ public class MerkleTreeTest
         assertHashEquals(leftval, mt.hash(left));
         assertHashEquals(partialval, mt.hash(partial));
         assertHashEquals(val, mt.hash(right));
-        assertNull(mt.hash(linvalid));
-        assertNull(mt.hash(rinvalid));
+        assertFalse(mt.hashesRange(linvalid));
+        assertFalse(mt.hashesRange(rinvalid));
     }
 
     @Test
@@ -258,10 +259,6 @@ public class MerkleTreeTest
         mt.split(tok(2));
         mt.split(tok(6));
         mt.split(tok(1));
-        assertNull(mt.hash(full));
-        assertNull(mt.hash(lchild));
-        assertNull(mt.hash(rchild));
-        assertNull(mt.hash(invalid));
 
         // validate the range
         mt.get(tok(1)).hash(val);
@@ -270,10 +267,14 @@ public class MerkleTreeTest
         mt.get(tok(6)).hash(val);
         mt.get(tok(-1)).hash(val);
 
+        assertTrue(mt.hashesRange(full));
+        assertTrue(mt.hashesRange(lchild));
+        assertTrue(mt.hashesRange(rchild));
+        assertFalse(mt.hashesRange(invalid));
+
         assertHashEquals(fullval, mt.hash(full));
         assertHashEquals(lchildval, mt.hash(lchild));
         assertHashEquals(rchildval, mt.hash(rchild));
-        assertNull(mt.hash(invalid));
     }
 
     @Test
@@ -294,9 +295,6 @@ public class MerkleTreeTest
         mt.split(tok(4));
         mt.split(tok(2));
         mt.split(tok(1));
-        assertNull(mt.hash(full));
-        assertNull(mt.hash(childfull));
-        assertNull(mt.hash(invalid));
 
         // validate the range
         mt.get(tok(1)).hash(val);
@@ -306,9 +304,12 @@ public class MerkleTreeTest
         mt.get(tok(16)).hash(val);
         mt.get(tok(-1)).hash(val);
 
+        assertTrue(mt.hashesRange(full));
+        assertTrue(mt.hashesRange(childfull));
+        assertFalse(mt.hashesRange(invalid));
+
         assertHashEquals(fullval, mt.hash(full));
         assertHashEquals(childfullval, mt.hash(childfull));
-        assertNull(mt.hash(invalid));
     }
 
     @Test
@@ -326,7 +327,7 @@ public class MerkleTreeTest
         }
 
         // validate the tree
-        TreeRangeIterator ranges = mt.invalids();
+        TreeRangeIterator ranges = mt.rangeIterator();
         for (TreeRange range : ranges)
             range.addHash(new RowHash(range.right, new byte[0], 0));
 
@@ -355,7 +356,7 @@ public class MerkleTreeTest
         mt.split(tok(6));
         mt.split(tok(10));
 
-        ranges = mt.invalids();
+        ranges = mt.rangeIterator();
         ranges.next().addAll(new HIterator(2, 4)); // (-1,4]: depth 2
         ranges.next().addAll(new HIterator(6)); // (4,6]
         ranges.next().addAll(new HIterator(8)); // (6,8]
@@ -372,7 +373,7 @@ public class MerkleTreeTest
         mt2.split(tok(9));
         mt2.split(tok(11));
 
-        ranges = mt2.invalids();
+        ranges = mt2.rangeIterator();
         ranges.next().addAll(new HIterator(2)); // (-1,2]
         ranges.next().addAll(new HIterator(4)); // (2,4]
         ranges.next().addAll(new HIterator(6, 8)); // (4,8]: depth 2
@@ -395,19 +396,33 @@ public class MerkleTreeTest
         // populate and validate the tree
         mt.maxsize(256);
         mt.init();
-        for (TreeRange range : mt.invalids())
+        for (TreeRange range : mt.rangeIterator())
             range.addAll(new HIterator(range.right));
 
         byte[] initialhash = mt.hash(full);
 
         DataOutputBuffer out = new DataOutputBuffer();
-        MerkleTree.serializer.serialize(mt, out, MessagingService.current_version);
+        mt.serialize(out, MessagingService.current_version);
         byte[] serialized = out.toByteArray();
 
-        DataInputPlus in = new DataInputBuffer(serialized);
-        MerkleTree restored = MerkleTree.serializer.deserialize(in, MessagingService.current_version);
+        MerkleTree restoredOnHeap =
+            MerkleTree.deserialize(new DataInputBuffer(serialized), false, MessagingService.current_version);
+        MerkleTree restoredOffHeap =
+            MerkleTree.deserialize(new DataInputBuffer(serialized), true, MessagingService.current_version);
+        MerkleTree movedOffHeap = mt.tryMoveOffHeap();
+
+        assertHashEquals(initialhash, restoredOnHeap.hash(full));
+        assertHashEquals(initialhash, restoredOffHeap.hash(full));
+        assertHashEquals(initialhash, movedOffHeap.hash(full));
+
+        assertEquals(mt, restoredOnHeap);
+        assertEquals(mt, restoredOffHeap);
+        assertEquals(mt, movedOffHeap);
+
+        assertEquals(restoredOnHeap, restoredOffHeap);
+        assertEquals(restoredOnHeap, movedOffHeap);
 
-        assertHashEquals(initialhash, restored.hash(full));
+        assertEquals(restoredOffHeap, movedOffHeap);
     }
 
     @Test
@@ -420,9 +435,9 @@ public class MerkleTreeTest
         mt2.init();
 
         // add dummy hashes to both trees
-        for (TreeRange range : mt.invalids())
+        for (TreeRange range : mt.rangeIterator())
             range.addAll(new HIterator(range.right));
-        for (TreeRange range : mt2.invalids())
+        for (TreeRange range : mt2.rangeIterator())
             range.addAll(new HIterator(range.right));
 
         TreeRange leftmost = null;
@@ -431,7 +446,7 @@ public class MerkleTreeTest
         mt.maxsize(maxsize + 2); // give some room for splitting
 
         // split the leftmost
-        Iterator<TreeRange> ranges = mt.invalids();
+        Iterator<TreeRange> ranges = mt.rangeIterator();
         leftmost = ranges.next();
         mt.split(leftmost.right);
 
@@ -465,18 +480,18 @@ public class MerkleTreeTest
         byte[] h2 = "hjkl".getBytes();
 
         // add dummy hashes to both trees
-        for (TreeRange tree : ltree.invalids())
+        for (TreeRange tree : ltree.rangeIterator())
         {
             tree.addHash(new RowHash(range.right, h1, h1.length));
         }
-        for (TreeRange tree : rtree.invalids())
+        for (TreeRange tree : rtree.rangeIterator())
         {
             tree.addHash(new RowHash(range.right, h2, h2.length));
         }
 
         List<TreeRange> diffs = MerkleTree.difference(ltree, rtree);
         assertEquals(Lists.newArrayList(range), diffs);
-        assertEquals(MerkleTree.FULLY_INCONSISTENT, MerkleTree.differenceHelper(ltree, rtree, new ArrayList<>(), new MerkleTree.TreeDifference(ltree.fullRange.left, ltree.fullRange.right, (byte) 0)));
+        assertEquals(MerkleTree.FULLY_INCONSISTENT, MerkleTree.differenceHelper(ltree, rtree, new ArrayList<>(), new MerkleTree.TreeRange(ltree.fullRange.left, ltree.fullRange.right, (byte)0)));
     }
 
     /**
@@ -499,11 +514,11 @@ public class MerkleTreeTest
 
 
         // add dummy hashes to both trees
-        for (TreeRange tree : ltree.invalids())
+        for (TreeRange tree : ltree.rangeIterator())
         {
             tree.addHash(new RowHash(range.right, h1, h1.length));
         }
-        for (TreeRange tree : rtree.invalids())
+        for (TreeRange tree : rtree.rangeIterator())
         {
             tree.addHash(new RowHash(range.right, h2, h2.length));
         }
@@ -533,7 +548,7 @@ public class MerkleTreeTest
             while (depth.equals(dstack.peek()))
             {
                 // consume the stack
-                hash = Hashable.binaryHash(hstack.pop(), hash);
+                hash = FBUtilities.xor(hstack.pop(), hash);
                 depth = dstack.pop() - 1;
             }
             dstack.push(depth);
diff --git a/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java b/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java
index b40f6c4..a1f6068 100644
--- a/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java
+++ b/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java
@@ -34,7 +34,6 @@ import org.apache.cassandra.io.util.DataInputBuffer;
 import org.apache.cassandra.io.util.DataOutputBuffer;
 import org.apache.cassandra.net.MessagingService;
 import org.apache.cassandra.service.StorageService;
-import org.apache.cassandra.utils.MerkleTree.Hashable;
 import org.apache.cassandra.utils.MerkleTree.RowHash;
 import org.apache.cassandra.utils.MerkleTree.TreeRange;
 import org.apache.cassandra.utils.MerkleTrees.TreeRangeIterator;
@@ -43,7 +42,7 @@ import static org.junit.Assert.*;
 
 public class MerkleTreesTest
 {
-    public static byte[] DUMMY = "blah".getBytes();
+    private static final byte[] DUMMY = HashingUtils.newMessageDigest("SHA-256").digest("dummy".getBytes());
 
     /**
      * If a test assumes that the tree is 8 units wide, then it should set this value
@@ -193,7 +192,7 @@ public class MerkleTreesTest
         Iterator<TreeRange> ranges;
 
         // (zero, zero]
-        ranges = mts.invalids();
+        ranges = mts.rangeIterator();
         assertEquals(new Range<>(tok(-1), tok(-1)), ranges.next());
         assertFalse(ranges.hasNext());
 
@@ -203,7 +202,7 @@ public class MerkleTreesTest
         mts.split(tok(6));
         mts.split(tok(3));
         mts.split(tok(5));
-        ranges = mts.invalids();
+        ranges = mts.rangeIterator();
         assertEquals(new Range<>(tok(6), tok(-1)), ranges.next());
         assertEquals(new Range<>(tok(-1), tok(2)), ranges.next());
         assertEquals(new Range<>(tok(2), tok(3)), ranges.next());
@@ -245,11 +244,6 @@ public class MerkleTreesTest
         // (zero,two] (two,four] (four, zero]
         mts.split(tok(4));
         mts.split(tok(2));
-        assertNull(mts.hash(left));
-        assertNull(mts.hash(partial));
-        assertNull(mts.hash(right));
-        assertNull(mts.hash(linvalid));
-        assertNull(mts.hash(rinvalid));
 
         // validate the range
         mts.get(tok(2)).hash(val);
@@ -280,10 +274,6 @@ public class MerkleTreesTest
         mts.split(tok(2));
         mts.split(tok(6));
         mts.split(tok(1));
-        assertNull(mts.hash(full));
-        assertNull(mts.hash(lchild));
-        assertNull(mts.hash(rchild));
-        assertNull(mts.hash(invalid));
 
         // validate the range
         mts.get(tok(1)).hash(val);
@@ -315,9 +305,6 @@ public class MerkleTreesTest
         mts.split(tok(4));
         mts.split(tok(2));
         mts.split(tok(1));
-        assertNull(mts.hash(full));
-        assertNull(mts.hash(childfull));
-        assertNull(mts.hash(invalid));
 
         // validate the range
         mts.get(tok(1)).hash(val);
@@ -349,7 +336,7 @@ public class MerkleTreesTest
         }
 
         // validate the tree
-        TreeRangeIterator ranges = mts.invalids();
+        TreeRangeIterator ranges = mts.rangeIterator();
         for (TreeRange range : ranges)
             range.addHash(new RowHash(range.right, new byte[0], 0));
 
@@ -378,13 +365,16 @@ public class MerkleTreesTest
         mts.split(tok(6));
         mts.split(tok(10));
 
-        ranges = mts.invalids();
-        ranges.next().addAll(new HIterator(2, 4)); // (-1,4]: depth 2
-        ranges.next().addAll(new HIterator(6)); // (4,6]
-        ranges.next().addAll(new HIterator(8)); // (6,8]
+        int seed = 123456789;
+
+        Random random1 = new Random(seed);
+        ranges = mts.rangeIterator();
+        ranges.next().addAll(new HIterator(random1, 2, 4)); // (-1,4]: depth 2
+        ranges.next().addAll(new HIterator(random1, 6)); // (4,6]
+        ranges.next().addAll(new HIterator(random1, 8)); // (6,8]
         ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (8,10]
-        ranges.next().addAll(new HIterator(12)); // (10,12]
-        ranges.next().addAll(new HIterator(14, -1)); // (12,-1]: depth 2
+        ranges.next().addAll(new HIterator(random1, 12)); // (10,12]
+        ranges.next().addAll(new HIterator(random1, 14, -1)); // (12,-1]: depth 2
 
 
         mts2.split(tok(8));
@@ -395,15 +385,16 @@ public class MerkleTreesTest
         mts2.split(tok(9));
         mts2.split(tok(11));
 
-        ranges = mts2.invalids();
-        ranges.next().addAll(new HIterator(2)); // (-1,2]
-        ranges.next().addAll(new HIterator(4)); // (2,4]
-        ranges.next().addAll(new HIterator(6, 8)); // (4,8]: depth 2
+        Random random2 = new Random(seed);
+        ranges = mts2.rangeIterator();
+        ranges.next().addAll(new HIterator(random2, 2)); // (-1,2]
+        ranges.next().addAll(new HIterator(random2, 4)); // (2,4]
+        ranges.next().addAll(new HIterator(random2, 6, 8)); // (4,8]: depth 2
         ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (8,9]
         ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (9,10]
         ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (10,11]: depth 4
-        ranges.next().addAll(new HIterator(12)); // (11,12]: depth 4
-        ranges.next().addAll(new HIterator(14, -1)); // (12,-1]: depth 2
+        ranges.next().addAll(new HIterator(random2, 12)); // (11,12]: depth 4
+        ranges.next().addAll(new HIterator(random2, 14, -1)); // (12,-1]: depth 2
 
         byte[] mthash = mts.hash(full);
         byte[] mt2hash = mts2.hash(full);
@@ -425,7 +416,7 @@ public class MerkleTreesTest
 
         // populate and validate the tree
         mts.init();
-        for (TreeRange range : mts.invalids())
+        for (TreeRange range : mts.rangeIterator())
             range.addAll(new HIterator(range.right));
 
         byte[] initialhash = mts.hash(first);
@@ -456,11 +447,15 @@ public class MerkleTreesTest
         mts.init();
         mts2.init();
 
+        int seed = 123456789;
         // add dummy hashes to both trees
-        for (TreeRange range : mts.invalids())
-            range.addAll(new HIterator(range.right));
-        for (TreeRange range : mts2.invalids())
-            range.addAll(new HIterator(range.right));
+        Random random1 = new Random(seed);
+        for (TreeRange range : mts.rangeIterator())
+            range.addAll(new HIterator(random1, range.right));
+
+        Random random2 = new Random(seed);
+        for (TreeRange range : mts2.rangeIterator())
+            range.addAll(new HIterator(random2, range.right));
 
         TreeRange leftmost = null;
         TreeRange middle = null;
@@ -468,7 +463,7 @@ public class MerkleTreesTest
         mts.maxsize(fullRange(), maxsize + 2); // give some room for splitting
 
         // split the leftmost
-        Iterator<TreeRange> ranges = mts.invalids();
+        Iterator<TreeRange> ranges = mts.rangeIterator();
         leftmost = ranges.next();
         mts.split(leftmost.right);
 
@@ -504,7 +499,7 @@ public class MerkleTreesTest
             while (depth.equals(dstack.peek()))
             {
                 // consume the stack
-                hash = Hashable.binaryHash(hstack.pop(), hash);
+                hash = FBUtilities.xor(hstack.pop(), hash);
                 depth = dstack.pop()-1;
             }
             dstack.push(depth);
@@ -516,25 +511,47 @@ public class MerkleTreesTest
 
     public static class HIterator extends AbstractIterator<RowHash>
     {
-        private Iterator<Token> tokens;
+        private final Random random;
+        private final Iterator<Token> tokens;
 
-        public HIterator(int... tokens)
+        HIterator(int... tokens)
         {
-            List<Token> tlist = new LinkedList<Token>();
+            this(new Random(), tokens);
+        }
+
+        HIterator(Random random, int... tokens)
+        {
+            List<Token> tlist = new ArrayList<>(tokens.length);
             for (int token : tokens)
                 tlist.add(tok(token));
             this.tokens = tlist.iterator();
+            this.random = random;
         }
 
         public HIterator(Token... tokens)
         {
-            this.tokens = Arrays.asList(tokens).iterator();
+            this(new Random(), tokens);
+        }
+
+        HIterator(Random random, Token... tokens)
+        {
+            this(random, Arrays.asList(tokens).iterator());
+        }
+
+        private HIterator(Random random, Iterator<Token> tokens)
+        {
+            this.random = random;
+            this.tokens = tokens;
         }
 
         public RowHash computeNext()
         {
             if (tokens.hasNext())
-                return new RowHash(tokens.next(), DUMMY, DUMMY.length);
+            {
+                byte[] digest = new byte[32];
+                random.nextBytes(digest);
+                return new RowHash(tokens.next(), digest, 12345L);
+            }
             return endOfData();
         }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org