You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by bl...@apache.org on 2022/10/21 10:11:25 UTC

[cassandra] 03/05: MemtableTrie using multiple buffers

This is an automated email from the ASF dual-hosted git repository.

blambov pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git

commit 562cb26010659830dd1192939ac815a0f6cb3502
Author: Branimir Lambov <br...@datastax.com>
AuthorDate: Thu Nov 11 15:39:21 2021 +0200

    MemtableTrie using multiple buffers
    
    This replaces the size doubling and copying required to grow the trie
    with an allocation of a new buffer. This improves the cost of expansion
    at the expense of increasing individual read and write costs.
    
    patch by Branimir Lambov; reviewed by Jason Rutherglen, Jacek Lewandowski, Andres de la Peña and Caleb Rackliffe for CASSANDRA-17240
---
 .../cassandra/db/tries/MemtableReadTrie.java       |  94 ++++++++--
 .../apache/cassandra/db/tries/MemtableTrie.java    | 192 ++++++++++++---------
 .../cassandra/db/tries/MemtableTrieTestBase.java   |   2 +-
 3 files changed, 192 insertions(+), 96 deletions(-)

diff --git a/src/java/org/apache/cassandra/db/tries/MemtableReadTrie.java b/src/java/org/apache/cassandra/db/tries/MemtableReadTrie.java
index 073a99e16e..e9c1e150ec 100644
--- a/src/java/org/apache/cassandra/db/tries/MemtableReadTrie.java
+++ b/src/java/org/apache/cassandra/db/tries/MemtableReadTrie.java
@@ -188,20 +188,67 @@ public class MemtableReadTrie<T> extends Trie<T>
 
     volatile int root;
 
-    final UnsafeBuffer buffer;
+    /*
+     EXPANDABLE DATA STORAGE
+
+     The tries will need more and more space in buffers and content lists as they grow. Instead of using ArrayList-like
+     reallocation with copying, which may be prohibitively expensive for large buffers, we use a sequence of
+     buffers/content arrays that double in size on every expansion.
+
+     For a given address x the index of the buffer can be found with the following calculation:
+        index_of_most_significant_set_bit(x / min_size + 1)
+     (relying on sum (2^i) for i in [0, n-1] == 2^n - 1) which can be performed quickly on modern hardware.
+
+     Finding the offset within the buffer is then
+        x + min_size - (min_size << buffer_index)
+
+     The allocated space starts at 256 bytes for the buffer and 16 entries for the content list.
+
+     Note that a buffer is not allowed to split 32-byte blocks (code assumes same buffer can be used for all bytes
+     inside the block).
+     */
+
+    static final int BUF_START_SHIFT = 8;
+    static final int BUF_START_SIZE = 1 << BUF_START_SHIFT;
+
+    static final int CONTENTS_START_SHIFT = 4;
+    static final int CONTENTS_START_SIZE = 1 << CONTENTS_START_SHIFT;
 
-    volatile AtomicReferenceArray<T> contentArray;
+    final UnsafeBuffer[] buffers;
+    final AtomicReferenceArray<T>[] contentArrays;
 
-    MemtableReadTrie(UnsafeBuffer buffer, AtomicReferenceArray<T> contentArray, int root)
+    MemtableReadTrie(UnsafeBuffer[] buffers, AtomicReferenceArray<T>[] contentArrays, int root)
     {
-        this.buffer = buffer;
-        this.contentArray = contentArray;
+        this.buffers = buffers;
+        this.contentArrays = contentArrays;
         this.root = root;
     }
 
     /*
      Buffer, content list and block management
      */
+    int getChunkIdx(int pos, int minChunkShift, int minChunkSize)
+    {
+        return 31 - minChunkShift - Integer.numberOfLeadingZeros(pos + minChunkSize);
+    }
+
+    int inChunkPointer(int pos, int chunkIndex, int minChunkSize)
+    {
+        return pos + minChunkSize - (minChunkSize << chunkIndex);
+    }
+
+    UnsafeBuffer getChunk(int pos)
+    {
+        int leadBit = getChunkIdx(pos, BUF_START_SHIFT, BUF_START_SIZE);
+        return buffers[leadBit];
+    }
+
+    int inChunkPointer(int pos)
+    {
+        int leadBit = getChunkIdx(pos, BUF_START_SHIFT, BUF_START_SIZE);
+        return inChunkPointer(pos, leadBit, BUF_START_SIZE);
+    }
+
 
     /**
      * Pointer offset for a node pointer.
@@ -213,19 +260,25 @@ public class MemtableReadTrie<T> extends Trie<T>
 
     final int getUnsignedByte(int pos)
     {
-        return buffer.getByte(pos) & 0xFF;
+        return getChunk(pos).getByte(inChunkPointer(pos)) & 0xFF;
     }
 
     final int getUnsignedShort(int pos)
     {
-        return buffer.getShort(pos) & 0xFFFF;
+        return getChunk(pos).getShort(inChunkPointer(pos)) & 0xFFFF;
     }
 
-    final int getInt(int pos) { return buffer.getInt(pos); }
+    final int getInt(int pos)
+    {
+        return getChunk(pos).getInt(inChunkPointer(pos));
+    }
 
     T getContent(int index)
     {
-        return contentArray.get(index);
+        int leadBit = getChunkIdx(index, CONTENTS_START_SHIFT, CONTENTS_START_SIZE);
+        int ofs = inChunkPointer(index, leadBit, CONTENTS_START_SIZE);
+        AtomicReferenceArray<T> array = contentArrays[leadBit];
+        return array.get(ofs);
     }
 
     /*
@@ -512,17 +565,19 @@ public class MemtableReadTrie<T> extends Trie<T>
                 return advance();
 
             // Jump directly to the chain's child.
+            UnsafeBuffer chunk = getChunk(node);
+            int inChunkNode = inChunkPointer(node);
             int bytesJumped = chainBlockLength(node) - 1;   // leave the last byte for incomingTransition
             if (receiver != null && bytesJumped > 0)
-                receiver.addPathBytes(buffer, node, bytesJumped);
+                receiver.addPathBytes(chunk, inChunkNode, bytesJumped);
             depth += bytesJumped;    // descendInto will add one
-            node += bytesJumped;
+            inChunkNode += bytesJumped;
 
             // inChunkNode is now positioned on the last byte of the chain.
             // Consume it to be the new state's incomingTransition.
-            int transition = getUnsignedByte(node++);
+            int transition = chunk.getByte(inChunkNode++) & 0xFF;
             // inChunkNode is now positioned on the child pointer.
-            int child = getInt(node);
+            int child = chunk.getInt(inChunkNode);
             return descendInto(child, transition);
         }
 
@@ -696,6 +751,9 @@ public class MemtableReadTrie<T> extends Trie<T>
 
         private int nextValidSparseTransition(int node, int data)
         {
+            UnsafeBuffer chunk = getChunk(node);
+            int inChunkNode = inChunkPointer(node);
+
             // Peel off the next index.
             int index = data % SPARSE_CHILD_COUNT;
             data = data / SPARSE_CHILD_COUNT;
@@ -705,20 +763,22 @@ public class MemtableReadTrie<T> extends Trie<T>
                 addBacktrack(node, data, depth);
 
             // Follow the transition.
-            int child = getInt(node + SPARSE_CHILDREN_OFFSET + index * 4);
-            int transition = getUnsignedByte(node + SPARSE_BYTES_OFFSET + index);
+            int child = chunk.getInt(inChunkNode + SPARSE_CHILDREN_OFFSET + index * 4);
+            int transition = chunk.getByte(inChunkNode + SPARSE_BYTES_OFFSET + index) & 0xFF;
             return descendInto(child, transition);
         }
 
         private int getChainTransition(int node)
         {
             // No backtracking needed.
-            int transition = getUnsignedByte(node);
+            UnsafeBuffer chunk = getChunk(node);
+            int inChunkNode = inChunkPointer(node);
+            int transition = chunk.getByte(inChunkNode) & 0xFF;
             int next = node + 1;
             if (offset(next) <= CHAIN_MAX_OFFSET)
                 return descendIntoChain(next, transition);
             else
-                return descendInto(getInt(node + 1), transition);
+                return descendInto(chunk.getInt(inChunkNode + 1), transition);
         }
 
         private int descendInto(int child, int transition)
diff --git a/src/java/org/apache/cassandra/db/tries/MemtableTrie.java b/src/java/org/apache/cassandra/db/tries/MemtableTrie.java
index 00d68fbf5a..45893fd247 100644
--- a/src/java/org/apache/cassandra/db/tries/MemtableTrie.java
+++ b/src/java/org/apache/cassandra/db/tries/MemtableTrie.java
@@ -25,14 +25,12 @@ import java.util.concurrent.atomic.AtomicReferenceArray;
 
 import com.google.common.annotations.VisibleForTesting;
 
-import org.slf4j.LoggerFactory;
-
 import org.agrona.concurrent.UnsafeBuffer;
 import org.apache.cassandra.config.CassandraRelevantProperties;
 import org.apache.cassandra.io.compress.BufferType;
+import org.apache.cassandra.io.util.FileUtils;
 import org.apache.cassandra.utils.bytecomparable.ByteSource;
 import org.apache.cassandra.utils.bytecomparable.ByteComparable;
-import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.ObjectSizes;
 import org.github.jamm.MemoryLayoutSpecification;
 
@@ -74,24 +72,24 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
 
     private final BufferType bufferType;    // on or off heap
 
-    private static final long EMPTY_SIZE_ON_HEAP; // for space calculations
-    private static final long EMPTY_SIZE_OFF_HEAP; // for space calculations
+    // constants for space calculations
+    private static final long EMPTY_SIZE_ON_HEAP;
+    private static final long EMPTY_SIZE_OFF_HEAP;
+    private static final long REFERENCE_ARRAY_ON_HEAP_SIZE = ObjectSizes.measureDeep(new AtomicReferenceArray(0));
 
     static
     {
         MemtableTrie<Object> empty = new MemtableTrie<>(BufferType.ON_HEAP);
-        EMPTY_SIZE_ON_HEAP = ObjectSizes.measureDeep(empty)
-                             - empty.contentArray.length() * MemoryLayoutSpecification.SPEC.getReferenceSize()
-                             - empty.buffer.capacity();
+        EMPTY_SIZE_ON_HEAP = ObjectSizes.measureDeep(empty);
         empty = new MemtableTrie<>(BufferType.OFF_HEAP);
-        EMPTY_SIZE_OFF_HEAP = ObjectSizes.measureDeep(empty)
-                              - empty.contentArray.length() * MemoryLayoutSpecification.SPEC.getReferenceSize()
-                              - empty.buffer.capacity();
+        EMPTY_SIZE_OFF_HEAP = ObjectSizes.measureDeep(empty);
     }
 
     public MemtableTrie(BufferType bufferType)
     {
-        super(new UnsafeBuffer(bufferType.allocate(INITIAL_BUFFER_CAPACITY)), new AtomicReferenceArray<>(16), NONE);
+        super(new UnsafeBuffer[31 - BUF_START_SHIFT],  // last one is 1G for a total of ~2G bytes
+              new AtomicReferenceArray[29 - CONTENTS_START_SHIFT],  // takes at least 4 bytes to write pointer to one content -> 4 times smaller than buffers
+              NONE);
         this.bufferType = bufferType;
         assert INITIAL_BUFFER_CAPACITY % BLOCK_SIZE == 0;
     }
@@ -114,40 +112,51 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         }
     }
 
+    final void putInt(int pos, int value)
+    {
+        getChunk(pos).putInt(inChunkPointer(pos), value);
+    }
+
+    final void putIntOrdered(int pos, int value)
+    {
+        getChunk(pos).putIntOrdered(inChunkPointer(pos), value);
+    }
+
+    final void putIntVolatile(int pos, int value)
+    {
+        getChunk(pos).putIntVolatile(inChunkPointer(pos), value);
+    }
+
+    final void putShort(int pos, short value)
+    {
+        getChunk(pos).putShort(inChunkPointer(pos), value);
+    }
+
+    final void putShortVolatile(int pos, short value)
+    {
+        getChunk(pos).putShortVolatile(inChunkPointer(pos), value);   // volatile store (not plain putShort): publishes the sparse order word to concurrent readers
+    }
+
+    final void putByte(int pos, byte value)
+    {
+        getChunk(pos).putByte(inChunkPointer(pos), value);
+    }
+
+
     private int allocateBlock() throws SpaceExhaustedException
     {
         // Note: If this method is modified, please run MemtableTrieTest.testOver1GSize to verify it acts correctly
         // close to the 2G limit.
         int v = allocatedPos;
-        if (buffer.capacity() == v)
-        {
-            int newSize;
-            if (v >= ALLOCATED_SIZE_THRESHOLD)
+        if (inChunkPointer(v) == 0)
             {
-                // we don't expect to write much after the threshold has been reached
-                // to avoid allocating too much space which will be left unused,
-                // grow by 10% of the limit, rounding up to BLOCK_SIZE
-                newSize = (v + ALLOCATED_SIZE_THRESHOLD / 10 + BLOCK_SIZE - 1) & -BLOCK_SIZE;
-                // If we do this repeatedly and the calculated size grows over 2G, it will overflow and result in a
-                // negative integer. In that case, cap it to a size that can be allocated.
-                if (newSize < 0)
-                {
-                    newSize = 0x7FFFFF00;   // 2G - 256 bytes
-                    if (newSize == allocatedPos)    // already at limit
+            int leadBit = getChunkIdx(v, BUF_START_SHIFT, BUF_START_SIZE);
+            if (leadBit + BUF_START_SHIFT == 31)   // chunk index is one past the last buffers[] slot: the ~2G address space is used up
                         throw new SpaceExhaustedException();
-                    LoggerFactory.getLogger(getClass()).debug("Growing memtable trie to maximum size {}",
-                                                              FBUtilities.prettyPrintMemory(newSize));
-                }
-                else
-                    LoggerFactory.getLogger(getClass()).debug("Growing memtable trie by 10% over the {} limit to {}",
-                                                              FBUtilities.prettyPrintMemory(ALLOCATED_SIZE_THRESHOLD),
-                                                              FBUtilities.prettyPrintMemory(newSize));
-            } else
-                newSize = v * 2;
-
-            ByteBuffer newBuffer = bufferType.allocate(newSize);
-            buffer.getBytes(0, newBuffer, v);
-            buffer.wrap(newBuffer);
+
+            assert buffers[leadBit] == null;
+            ByteBuffer newBuffer = bufferType.allocate(BUF_START_SIZE << leadBit);
+            buffers[leadBit] = new UnsafeBuffer(newBuffer);
             // The above does not contain any happens-before enforcing writes, thus at this point the new buffer may be
             // invisible to any concurrent readers. Touching the volatile root pointer (which any new read must go
             // through) enforces a happens-before that makes it visible to all new reads (note: when the write completes
@@ -163,21 +172,37 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
     private int addContent(T value)
     {
         int index = contentCount++;
-        if (index == contentArray.length())
+        int leadBit = getChunkIdx(index, CONTENTS_START_SHIFT, CONTENTS_START_SIZE);
+        int ofs = inChunkPointer(index, leadBit, CONTENTS_START_SIZE);
+        AtomicReferenceArray<T> array = contentArrays[leadBit];
+        if (array == null)
         {
-            AtomicReferenceArray<T> newContent = new AtomicReferenceArray<>(index * 2);
-            for (int i = 0; i < contentArray.length(); ++i)
-                newContent.lazySet(i, contentArray.get(i));
-            contentArray = newContent;  // This is a volatile set, hence all previous stores must become visible
+            assert ofs == 0;
+            contentArrays[leadBit] = array = new AtomicReferenceArray<>(CONTENTS_START_SIZE << leadBit);
         }
-        contentArray.lazySet(index, value); // no need for a volatile set here; at this point the item is not referenced
+        array.lazySet(ofs, value); // no need for a volatile set here; at this point the item is not referenced
                                             // by any node in the trie, and a volatile set will be made to reference it.
         return index;
     }
 
     private void setContent(int index, T value)
     {
-        contentArray.set(index, value);
+        int leadBit = getChunkIdx(index, CONTENTS_START_SHIFT, CONTENTS_START_SIZE);
+        int ofs = inChunkPointer(index, leadBit, CONTENTS_START_SIZE);
+        AtomicReferenceArray<T> array = contentArrays[leadBit];
+        array.set(ofs, value);
+    }
+
+    public void discardBuffers()
+    {
+        if (bufferType == BufferType.ON_HEAP)
+            return; // no cleaning needed
+
+        for (UnsafeBuffer b : buffers)
+        {
+            if (b != null)
+                FileUtils.clean(b.byteBuffer());
+        }
     }
 
     // Write methods
@@ -212,7 +237,7 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
                 // If this is the last character in a Chain block, we can modify the child in-place
                 if (trans == getUnsignedByte(node))
                 {
-                    buffer.putIntVolatile(node + 1, newChild);
+                    putIntVolatile(node + 1, newChild);
                     return node;
                 }
                 // else pass through
@@ -234,9 +259,9 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
             int tailPos = splitBlockPointerAddress(mid, splitNodeTailIndex(trans), SPLIT_OTHER_LEVEL_LIMIT);
             int tail = createEmptySplitNode();
             int childPos = splitBlockPointerAddress(tail, splitNodeChildIndex(trans), SPLIT_OTHER_LEVEL_LIMIT);
-            buffer.putInt(childPos, newChild);
-            buffer.putInt(tailPos, tail);
-            buffer.putIntVolatile(midPos, mid);
+            putInt(childPos, newChild);
+            putInt(tailPos, tail);
+            putIntVolatile(midPos, mid);
             return;
         }
 
@@ -246,13 +271,13 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         {
             tail = createEmptySplitNode();
             int childPos = splitBlockPointerAddress(tail, splitNodeChildIndex(trans), SPLIT_OTHER_LEVEL_LIMIT);
-            buffer.putInt(childPos, newChild);
-            buffer.putIntVolatile(tailPos, tail);
+            putInt(childPos, newChild);
+            putIntVolatile(tailPos, tail);
             return;
         }
 
         int childPos = splitBlockPointerAddress(tail, splitNodeChildIndex(trans), SPLIT_OTHER_LEVEL_LIMIT);
-        buffer.putIntVolatile(childPos, newChild);
+        putIntVolatile(childPos, newChild);
     }
 
     /**
@@ -270,7 +295,7 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
             final int existing = getUnsignedByte(node + SPARSE_BYTES_OFFSET + index);
             if (existing == trans)
             {
-                buffer.putIntVolatile(node + SPARSE_CHILDREN_OFFSET + index * 4, newChild);
+                putIntVolatile(node + SPARSE_CHILDREN_OFFSET + index * 4, newChild);
                 return node;
             }
             else if (existing < trans)
@@ -293,7 +318,7 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         }
 
         // Add a new transition. They are not kept in order, so append it at the first free position.
-        buffer.putByte(node + SPARSE_BYTES_OFFSET + childCount, (byte) trans);
+        putByte(node + SPARSE_BYTES_OFFSET + childCount, (byte) trans);
 
         // Update order word.
         int order = getUnsignedShort(node + SPARSE_ORDER_OFFSET);
@@ -308,11 +333,11 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         // correct value (see getSparseChild).
 
         // setting child enables reads to start seeing the new branch
-        buffer.putIntVolatile(node + SPARSE_CHILDREN_OFFSET + childCount * 4, newChild);
+        putIntVolatile(node + SPARSE_CHILDREN_OFFSET + childCount * 4, newChild);
 
         // some readers will decide whether to check the pointer based on the order word
         // write that volatile to make sure they see the new change too
-        buffer.putShortVolatile(node + SPARSE_ORDER_OFFSET,  (short) newOrder);
+        putShortVolatile(node + SPARSE_ORDER_OFFSET,  (short) newOrder);
         return node;
     }
 
@@ -348,7 +373,7 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         if (isNull(mid))
         {
             mid = createEmptySplitNode();
-            buffer.putInt(midPos, mid);
+            putInt(midPos, mid);
         }
 
         assert offset(mid) == SPLIT_OFFSET : "Invalid split node in trie";
@@ -357,12 +382,12 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         if (isNull(tail))
         {
             tail = createEmptySplitNode();
-            buffer.putInt(tailPos, tail);
+            putInt(tailPos, tail);
         }
 
         assert offset(tail) == SPLIT_OFFSET : "Invalid split node in trie";
         int childPos = splitBlockPointerAddress(tail, splitNodeChildIndex(trans), SPLIT_OTHER_LEVEL_LIMIT);
-        buffer.putInt(childPos, newChild);
+        putInt(childPos, newChild);
     }
 
     /**
@@ -416,11 +441,11 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         }
 
         int node = allocateBlock() + SPARSE_OFFSET;
-        buffer.putByte(node + SPARSE_BYTES_OFFSET + 0,  (byte) byte1);
-        buffer.putByte(node + SPARSE_BYTES_OFFSET + 1,  (byte) byte2);
-        buffer.putInt(node + SPARSE_CHILDREN_OFFSET + 0 * 4, child1);
-        buffer.putInt(node + SPARSE_CHILDREN_OFFSET + 1 * 4, child2);
-        buffer.putShort(node + SPARSE_ORDER_OFFSET,  (short) (1 * 6 + 0));
+        putByte(node + SPARSE_BYTES_OFFSET + 0,  (byte) byte1);
+        putByte(node + SPARSE_BYTES_OFFSET + 1,  (byte) byte2);
+        putInt(node + SPARSE_CHILDREN_OFFSET + 0 * 4, child1);
+        putInt(node + SPARSE_CHILDREN_OFFSET + 1 * 4, child2);
+        putShort(node + SPARSE_ORDER_OFFSET,  (short) (1 * 6 + 0));
         // Note: this does not need a volatile write as it is a new node, returning a new pointer, which needs to be
         // put in an existing node or the root. That action ends in a happens-before enforcing write.
         return node;
@@ -434,8 +459,8 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
     private int createNewChainNode(int transitionByte, int newChild) throws SpaceExhaustedException
     {
         int newNode = allocateBlock() + LAST_POINTER_OFFSET - 1;
-        buffer.putByte(newNode, (byte) transitionByte);
-        buffer.putInt(newNode + 1, newChild);
+        putByte(newNode, (byte) transitionByte);
+        putInt(newNode + 1, newChild);
         // Note: this does not need a volatile write as it is a new node, returning a new pointer, which needs to be
         // put in an existing node or the root. That action ends in a happens-before enforcing write.
         return newNode;
@@ -449,7 +474,7 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         {
             // attach as a new character in child node
             int newNode = newChild - 1;
-            buffer.putByte(newNode, (byte) transitionByte);
+            putByte(newNode, (byte) transitionByte);
             return newNode;
         }
 
@@ -474,17 +499,17 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
             // creating the embedded node may overwrite information that is still needed by concurrent readers or the
             // mutation process itself.
             node = (child & -BLOCK_SIZE) | PREFIX_OFFSET;
-            buffer.putByte(node + PREFIX_FLAGS_OFFSET, (byte) offset);
+            putByte(node + PREFIX_FLAGS_OFFSET, (byte) offset);
         }
         else
         {
             // Full prefix node
             node = allocateBlock() + PREFIX_OFFSET;
-            buffer.putByte(node + PREFIX_FLAGS_OFFSET, (byte) 0xFF);
-            buffer.putInt(node + PREFIX_POINTER_OFFSET, child);
+            putByte(node + PREFIX_FLAGS_OFFSET, (byte) 0xFF);
+            putInt(node + PREFIX_POINTER_OFFSET, child);
         }
 
-        buffer.putInt(node + PREFIX_CONTENT_OFFSET, contentIndex);
+        putInt(node + PREFIX_CONTENT_OFFSET, contentIndex);
         return node;
     }
 
@@ -497,7 +522,7 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
         if (!isEmbeddedPrefixNode(node))
         {
             // This attaches the child branch and makes it reachable -- the write must be volatile.
-            buffer.putIntVolatile(node + PREFIX_POINTER_OFFSET, child);
+            putIntVolatile(node + PREFIX_POINTER_OFFSET, child);
             return node;
         }
         else
@@ -686,7 +711,7 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
             {
                 if (existingContentIndex != -1)
                 {
-                    final T existingContent = contentArray.get(existingContentIndex);
+                    final T existingContent = getContent(existingContentIndex);
                     T combinedContent = transformer.apply(existingContent, mutationContent);
                     assert (combinedContent != null) : "Transformer cannot be used to remove content.";
                     setContent(existingContentIndex, combinedContent);
@@ -743,7 +768,7 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
 
             // Otherwise modify in place
             if (updatedPostContentNode != existingPostContentNode) // to use volatile write but also ensure we don't corrupt embedded nodes
-                buffer.putIntVolatile(existingPreContentNode + PREFIX_POINTER_OFFSET, updatedPostContentNode);
+                putIntVolatile(existingPreContentNode + PREFIX_POINTER_OFFSET, updatedPostContentNode);
             assert contentIndex == getInt(existingPreContentNode + PREFIX_CONTENT_OFFSET) : "Unexpected change of content index.";
             return existingPreContentNode;
         }
@@ -960,7 +985,9 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
     public long sizeOnHeap()
     {
         return contentCount * MemoryLayoutSpecification.SPEC.getReferenceSize() +
-               (bufferType == BufferType.ON_HEAP ? allocatedPos + EMPTY_SIZE_ON_HEAP : EMPTY_SIZE_OFF_HEAP);
+               REFERENCE_ARRAY_ON_HEAP_SIZE * getChunkIdx(contentCount, CONTENTS_START_SHIFT, CONTENTS_START_SIZE) +
+               (bufferType == BufferType.ON_HEAP ? allocatedPos + EMPTY_SIZE_ON_HEAP : EMPTY_SIZE_OFF_HEAP) +
+               REFERENCE_ARRAY_ON_HEAP_SIZE * getChunkIdx(allocatedPos, BUF_START_SHIFT, BUF_START_SIZE);
     }
 
     @Override
@@ -994,9 +1021,18 @@ public class MemtableTrie<T> extends MemtableReadTrie<T>
     {
         int bufferOverhead = 0;
         if (bufferType == BufferType.ON_HEAP)
-            bufferOverhead = buffer.capacity() - this.allocatedPos;
+        {
+            int pos = this.allocatedPos;
+            UnsafeBuffer buffer = getChunk(pos);
+            if (buffer != null)
+                bufferOverhead = buffer.capacity() - inChunkPointer(pos);
+        }
 
-        int contentOverhead = (contentArray.length() - contentCount) * MemoryLayoutSpecification.SPEC.getReferenceSize();
+        int index = contentCount;
+        int leadBit = getChunkIdx(index, CONTENTS_START_SHIFT, CONTENTS_START_SIZE);
+        int ofs = inChunkPointer(index, leadBit, CONTENTS_START_SIZE);
+        AtomicReferenceArray<T> contentArray = contentArrays[leadBit];
+        int contentOverhead = ((contentArray != null ? contentArray.length() : 0) - ofs) * MemoryLayoutSpecification.SPEC.getReferenceSize();
 
         return bufferOverhead + contentOverhead;
     }
diff --git a/test/unit/org/apache/cassandra/db/tries/MemtableTrieTestBase.java b/test/unit/org/apache/cassandra/db/tries/MemtableTrieTestBase.java
index f7d309b874..23143cdc35 100644
--- a/test/unit/org/apache/cassandra/db/tries/MemtableTrieTestBase.java
+++ b/test/unit/org/apache/cassandra/db/tries/MemtableTrieTestBase.java
@@ -303,7 +303,7 @@ public abstract class MemtableTrieTestBase
                             .mapToInt(src1 -> ByteComparable.length(src1, VERSION))
                             .sum();
         long ts = ObjectSizes.measureDeep(content);
-        long onh = ObjectSizes.measureDeep(trie.contentArray);
+        long onh = ObjectSizes.measureDeep(trie.contentArrays);
         System.out.format("Trie size on heap %,d off heap %,d measured %,d keys %,d treemap %,d\n",
                           trie.sizeOnHeap(), trie.sizeOffHeap(), onh, keysize, ts);
         System.out.format("per entry on heap %.2f off heap %.2f measured %.2f keys %.2f treemap %.2f\n",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org