This is an automated email from the ASF dual-hosted git repository.

maedhroz pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git

commit 2531cb045897d5b771f79039d194a1f679d8629a
Author: Mike Adamson <madam...@datastax.com>
AuthorDate: Thu Jul 13 11:24:55 2023 +0100

    Fix concurrency in bbtree reader by cloning state
    
    patch by Mike Adamson; reviewed by Andrés de la Peña and Caleb Rackliffe for CASSANDRA-18669
---
 .../disk/v1/bbtree/BlockBalancedTreeReader.java    |  33 ++--
 .../disk/v1/bbtree/BlockBalancedTreeWalker.java    | 152 +++++++++-------
 .../v1/bbtree/BlockBalancedTreeReaderTest.java     | 192 +++++++++++++--------
 .../sai/disk/v1/bbtree/BlockBalancedTreeTest.java  |  54 +++---
 4 files changed, 265 insertions(+), 166 deletions(-)

diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeReader.java
index 59271e3c14..53cac195e0 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeReader.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeReader.java
@@ -78,17 +78,17 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
         this.indexContext = indexContext;
         this.postingsFile = postingsFile;
         this.postingsIndex = new BlockBalancedTreePostingsIndex(postingsFile, 
treePostingsRoot);
-        leafOrderMapBitsRequired = 
DirectWriter.unsignedBitsRequired(state.maxPointsInLeafNode - 1);
+        leafOrderMapBitsRequired = 
DirectWriter.unsignedBitsRequired(maxValuesInLeafNode - 1);
     }
 
     public int getBytesPerValue()
     {
-        return state.bytesPerValue;
+        return bytesPerValue;
     }
 
     public long getPointCount()
     {
-        return state.valueCount;
+        return valueCount;
     }
 
     @Override
@@ -101,7 +101,7 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
     @SuppressWarnings({"resource", "RedundantSuppression"})
     public PostingList intersect(IntersectVisitor visitor, 
QueryEventListener.BalancedTreeEventListener listener, QueryContext context)
     {
-        Relation relation = visitor.compare(state.minPackedValue, 
state.maxPackedValue);
+        Relation relation = visitor.compare(minPackedValue, maxPackedValue);
 
         if (relation == Relation.CELL_OUTSIDE_QUERY)
         {
@@ -113,7 +113,6 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
         IndexInput treeInput = 
IndexFileUtils.instance.openInput(treeIndexFile);
         IndexInput postingsInput = 
IndexFileUtils.instance.openInput(postingsFile);
         IndexInput postingsSummaryInput = 
IndexFileUtils.instance.openInput(postingsFile);
-        state.reset();
 
         Intersection intersection = relation == Relation.CELL_INSIDE_QUERY
                                     ? new Intersection(treeInput, 
postingsInput, postingsSummaryInput, listener, context)
@@ -131,6 +130,7 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
         private final Stopwatch queryExecutionTimer = 
Stopwatch.createStarted();
         final QueryContext context;
 
+        final TraversalState state;
         final IndexInput treeInput;
         final IndexInput postingsInput;
         final IndexInput postingsSummaryInput;
@@ -140,12 +140,13 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
         Intersection(IndexInput treeInput, IndexInput postingsInput, 
IndexInput postingsSummaryInput,
                      QueryEventListener.BalancedTreeEventListener listener, 
QueryContext context)
         {
+            this.state = newTraversalState();
             this.treeInput = treeInput;
             this.postingsInput = postingsInput;
             this.postingsSummaryInput = postingsSummaryInput;
             this.listener = listener;
             this.context = context;
-            postingLists = new PriorityQueue<>(state.numLeaves, COMPARATOR);
+            postingLists = new PriorityQueue<>(numLeaves, COMPARATOR);
         }
 
         public PostingList execute()
@@ -247,14 +248,14 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
         {
             super(treeInput, postingsInput, postingsSummaryInput, listener, 
context);
             this.visitor = visitor;
-            this.packedValue = new byte[state.bytesPerValue];
-            this.origIndex = new short[state.maxPointsInLeafNode];
+            this.packedValue = new byte[bytesPerValue];
+            this.origIndex = new short[maxValuesInLeafNode];
         }
 
         @Override
         public void executeInternal() throws IOException
         {
-            collectPostingLists(state.minPackedValue, state.maxPackedValue);
+            collectPostingLists(minPackedValue, maxPackedValue);
         }
 
         private void collectPostingLists(byte[] minPackedValue, byte[] 
maxPackedValue) throws IOException
@@ -320,8 +321,8 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
             if (BlockBalancedTreeWriter.DEBUG)
             {
                 // make sure cellMin <= splitValue <= cellMax:
-                assert ByteArrayUtil.compareUnsigned(minPackedValue, 0, 
splitValue, 0, state.bytesPerValue) <= 0 :"bytesPerValue=" + 
state.bytesPerValue;
-                assert ByteArrayUtil.compareUnsigned(maxPackedValue, 0, 
splitValue, 0, state.bytesPerValue) >= 0 : "bytesPerValue=" + 
state.bytesPerValue;
+                assert ByteArrayUtil.compareUnsigned(minPackedValue, 0, 
splitValue, 0, bytesPerValue) <= 0 :"bytesPerValue=" + bytesPerValue;
+                assert ByteArrayUtil.compareUnsigned(maxPackedValue, 0, 
splitValue, 0, bytesPerValue) >= 0 : "bytesPerValue=" + bytesPerValue;
             }
 
             // Recurse on left subtree:
@@ -346,8 +347,8 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
         private FixedBitSet buildPostingsFilter(IndexInput in, int count, 
IntersectVisitor visitor, short[] origIndex) throws IOException
         {
             int commonPrefixLength = readCommonPrefixLength(in);
-            return commonPrefixLength == state.bytesPerValue ? 
buildPostingsFilterForSingleValueLeaf(count, visitor, origIndex)
-                                                             : 
buildPostingsFilterForMultiValueLeaf(commonPrefixLength, in, count, visitor, 
origIndex);
+            return commonPrefixLength == bytesPerValue ? 
buildPostingsFilterForSingleValueLeaf(count, visitor, origIndex)
+                                                       : 
buildPostingsFilterForMultiValueLeaf(commonPrefixLength, in, count, visitor, 
origIndex);
         }
 
         private FixedBitSet buildPostingsFilterForMultiValueLeaf(int 
commonPrefixLength,
@@ -362,7 +363,7 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
             commonPrefixLength++;
             int i;
 
-            FixedBitSet fixedBitSet = new 
FixedBitSet(state.maxPointsInLeafNode);
+            FixedBitSet fixedBitSet = new FixedBitSet(maxValuesInLeafNode);
 
             for (i = 0; i < count; )
             {
@@ -370,7 +371,7 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
                 final int runLen = Byte.toUnsignedInt(in.readByte());
                 for (int j = 0; j < runLen; ++j)
                 {
-                    in.readBytes(packedValue, commonPrefixLength, 
state.bytesPerValue - commonPrefixLength);
+                    in.readBytes(packedValue, commonPrefixLength, 
bytesPerValue - commonPrefixLength);
                     final int rowIDIndex = origIndex[i + j];
                     if (visitor.contains(packedValue))
                         fixedBitSet.set(rowIDIndex);
@@ -385,7 +386,7 @@ public class BlockBalancedTreeReader extends 
BlockBalancedTreeWalker implements
 
         private FixedBitSet buildPostingsFilterForSingleValueLeaf(int count, 
IntersectVisitor visitor, final short[] origIndex)
         {
-            FixedBitSet fixedBitSet = new 
FixedBitSet(state.maxPointsInLeafNode);
+            FixedBitSet fixedBitSet = new FixedBitSet(maxValuesInLeafNode);
 
             // All the values in the leaf are the same, so we only
             // need to visit once then set the bits for the relevant indexes
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeWalker.java b/src/java/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeWalker.java
index ebfcd4ce5f..5a01b81f09 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeWalker.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeWalker.java
@@ -21,6 +21,8 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.util.Arrays;
 
+import javax.annotation.concurrent.NotThreadSafe;
+
 import com.google.common.annotations.VisibleForTesting;
 
 import org.agrona.collections.IntArrayList;
@@ -46,7 +48,15 @@ import org.apache.lucene.util.BytesRef;
 public class BlockBalancedTreeWalker implements Closeable
 {
     final FileHandle treeIndexFile;
-    final TraversalState state;
+    final int bytesPerValue;
+    final int numLeaves;
+    final int treeDepth;
+    final byte[] minPackedValue;
+    final byte[] maxPackedValue;
+    final long valueCount;
+    final int maxValuesInLeafNode;
+    final byte[] packedIndex;
+    final long memoryUsage;
 
     BlockBalancedTreeWalker(FileHandle treeIndexFile, long treeIndexRoot)
     {
@@ -58,7 +68,35 @@ public class BlockBalancedTreeWalker implements Closeable
             SAICodecUtils.validate(indexInput);
             indexInput.seek(treeIndexRoot);
 
-            state = new TraversalState(indexInput);
+            maxValuesInLeafNode = indexInput.readVInt();
+            bytesPerValue = indexInput.readVInt();
+
+            // Read index:
+            numLeaves = indexInput.readVInt();
+            assert numLeaves > 0;
+            treeDepth = indexInput.readVInt();
+            minPackedValue = new byte[bytesPerValue];
+            maxPackedValue = new byte[bytesPerValue];
+
+            indexInput.readBytes(minPackedValue, 0, bytesPerValue);
+            indexInput.readBytes(maxPackedValue, 0, bytesPerValue);
+
+            if (ByteArrayUtil.compareUnsigned(minPackedValue, 0, 
maxPackedValue, 0, bytesPerValue) > 0)
+            {
+                String message = String.format("Min packed value %s is > max 
packed value %s.",
+                                               new BytesRef(minPackedValue), 
new BytesRef(maxPackedValue));
+                throw new CorruptIndexException(message, indexInput);
+            }
+
+            valueCount = indexInput.readVLong();
+
+            int numBytes = indexInput.readVInt();
+            packedIndex = new byte[numBytes];
+            indexInput.readBytes(packedIndex, 0, numBytes);
+
+            memoryUsage = ObjectSizes.sizeOfArray(packedIndex) +
+                          ObjectSizes.sizeOfArray(minPackedValue) +
+                          ObjectSizes.sizeOfArray(maxPackedValue);
         }
         catch (Throwable t)
         {
@@ -67,9 +105,52 @@ public class BlockBalancedTreeWalker implements Closeable
         }
     }
 
+    @VisibleForTesting
+    public BlockBalancedTreeWalker(DataInput indexInput, long treeIndexRoot) 
throws IOException
+    {
+        treeIndexFile = null;
+
+        indexInput.skipBytes(treeIndexRoot);
+
+        maxValuesInLeafNode = indexInput.readVInt();
+        bytesPerValue = indexInput.readVInt();
+
+        // Read index:
+        numLeaves = indexInput.readVInt();
+        assert numLeaves > 0;
+        treeDepth = indexInput.readVInt();
+        minPackedValue = new byte[bytesPerValue];
+        maxPackedValue = new byte[bytesPerValue];
+
+        indexInput.readBytes(minPackedValue, 0, bytesPerValue);
+        indexInput.readBytes(maxPackedValue, 0, bytesPerValue);
+
+        if (ByteArrayUtil.compareUnsigned(minPackedValue, 0, maxPackedValue, 
0, bytesPerValue) > 0)
+        {
+            String message = String.format("Min packed value %s is > max 
packed value %s.",
+                                           new BytesRef(minPackedValue), new 
BytesRef(maxPackedValue));
+            throw new CorruptIndexException(message, indexInput);
+        }
+
+        valueCount = indexInput.readVLong();
+
+        int numBytes = indexInput.readVInt();
+        packedIndex = new byte[numBytes];
+        indexInput.readBytes(packedIndex, 0, numBytes);
+
+        memoryUsage = ObjectSizes.sizeOfArray(packedIndex) +
+                      ObjectSizes.sizeOfArray(minPackedValue) +
+                      ObjectSizes.sizeOfArray(maxPackedValue);
+    }
+
     public long memoryUsage()
     {
-        return state.memoryUsage;
+        return memoryUsage;
+    }
+
+    public TraversalState newTraversalState()
+    {
+        return new TraversalState();
     }
 
     @Override
@@ -80,11 +161,10 @@ public class BlockBalancedTreeWalker implements Closeable
 
     void traverse(TraversalCallback callback)
     {
-        state.reset();
-        traverse(callback, new IntArrayList());
+        traverse(newTraversalState(), callback, new IntArrayList());
     }
 
-    private void traverse(TraversalCallback callback, IntArrayList pathToRoot)
+    private void traverse(TraversalState state, TraversalCallback callback, 
IntArrayList pathToRoot)
     {
         if (state.atLeafNode())
         {
@@ -101,11 +181,11 @@ public class BlockBalancedTreeWalker implements Closeable
             currentPath.add(state.nodeID);
 
             state.pushLeft();
-            traverse(callback, currentPath);
+            traverse(state, callback, currentPath);
             state.pop();
 
             state.pushRight();
-            traverse(callback, currentPath);
+            traverse(state, callback, currentPath);
             state.pop();
         }
     }
@@ -143,17 +223,9 @@ public class BlockBalancedTreeWalker implements Closeable
      *    2[0-16]   3[16-32]
      * </pre>
      */
-    final static class TraversalState
+    @NotThreadSafe
+    final class TraversalState
     {
-        final int bytesPerValue;
-        final int numLeaves;
-        final int treeDepth;
-        final byte[] minPackedValue;
-        final byte[] maxPackedValue;
-        final long valueCount;
-        final int maxPointsInLeafNode;
-        final long memoryUsage;
-
         // used to read the packed index byte[]
         final ByteArrayDataInput dataInput;
         // holds the minimum (left most) leaf block file pointer for each 
level we've recursed to:
@@ -170,60 +242,18 @@ public class BlockBalancedTreeWalker implements Closeable
         @VisibleForTesting
         int maxLevel;
 
-        TraversalState(DataInput dataInput) throws IOException
+        private TraversalState()
         {
-            maxPointsInLeafNode = dataInput.readVInt();
-            bytesPerValue = dataInput.readVInt();
-
-            // Read index:
-            numLeaves = dataInput.readVInt();
-            assert numLeaves > 0;
-            treeDepth = dataInput.readVInt();
-            minPackedValue = new byte[bytesPerValue];
-            maxPackedValue = new byte[bytesPerValue];
-
-            dataInput.readBytes(minPackedValue, 0, bytesPerValue);
-            dataInput.readBytes(maxPackedValue, 0, bytesPerValue);
-
-            if (ByteArrayUtil.compareUnsigned(minPackedValue, 0, 
maxPackedValue, 0, bytesPerValue) > 0)
-            {
-                String message = String.format("Min packed value %s is > max 
packed value %s.",
-                                               new BytesRef(minPackedValue), 
new BytesRef(maxPackedValue));
-                throw new CorruptIndexException(message, dataInput);
-            }
-
-            valueCount = dataInput.readVLong();
-
-            int numBytes = dataInput.readVInt();
-            byte[] packedIndex = new byte[numBytes];
-            dataInput.readBytes(packedIndex, 0, numBytes);
-
             nodeID = 1;
             level = 0;
             leafBlockFPStack = new long[treeDepth];
             leftNodePositions = new int[treeDepth];
             rightNodePositions = new int[treeDepth];
             splitValuesStack = new byte[treeDepth][];
-
-            memoryUsage = ObjectSizes.sizeOfArray(packedIndex) +
-                          ObjectSizes.sizeOfArray(minPackedValue) +
-                          ObjectSizes.sizeOfArray(maxPackedValue) +
-                          ObjectSizes.sizeOfArray(leafBlockFPStack) +
-                          ObjectSizes.sizeOfArray(leftNodePositions) +
-                          ObjectSizes.sizeOfArray(rightNodePositions) +
-                          ObjectSizes.sizeOfArray(splitValuesStack) * 
bytesPerValue;
-
             this.dataInput = new ByteArrayDataInput(packedIndex);
             readNodeData(false);
         }
 
-        public void reset()
-        {
-            nodeID = 1;
-            level = 0;
-            dataInput.setPosition(0);
-        }
-
         public void pushLeft()
         {
             int nodePosition = leftNodePositions[level];
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeReaderTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeReaderTest.java
index 984a064c54..59e54c5708 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeReaderTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeReaderTest.java
@@ -17,6 +17,13 @@
  */
 package org.apache.cassandra.index.sai.disk.v1.bbtree;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
 import org.junit.Before;
 import org.junit.Test;
 
@@ -33,6 +40,8 @@ import org.apache.cassandra.index.sai.plan.Expression;
 import org.apache.cassandra.index.sai.postings.PostingList;
 import org.apache.cassandra.index.sai.utils.SAIRandomizedTester;
 import org.apache.cassandra.io.util.FileHandle;
+import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.Throwables;
 import org.apache.lucene.index.PointValues.Relation;
 import org.apache.lucene.util.NumericUtils;
 
@@ -98,22 +107,16 @@ public class BlockBalancedTreeReaderTest extends 
SAIRandomizedTester
         final BlockBalancedTreeRamBuffer buffer = new 
BlockBalancedTreeRamBuffer(Integer.BYTES);
 
         byte[] scratch = new byte[4];
-        for (int docID = 0; docID < numRows; docID++)
+        for (int rowID = 0; rowID < numRows; rowID++)
         {
-            NumericUtils.intToSortableBytes(docID, scratch, 0);
-            buffer.add(docID, scratch);
+            NumericUtils.intToSortableBytes(rowID, scratch, 0);
+            buffer.add(rowID, scratch);
         }
 
-        final BlockBalancedTreeReader reader = finishAndOpenReader(4, buffer);
-
-        Expression expression = new Expression(indexContext);
-        expression.add(Operator.GT, Int32Type.instance.decompose(444));
-        expression.add(Operator.LT, Int32Type.instance.decompose(555));
-        PostingList intersection = performIntersection(reader, 
BlockBalancedTreeQueries.balancedTreeQueryFrom(expression, 4));
-        assertNotNull(intersection);
-        assertEquals(110, intersection.size());
-        for (long posting = 445; posting < 555; posting++)
-            assertEquals(posting, intersection.nextPosting());
+        try (BlockBalancedTreeReader reader = finishAndOpenReader(4, buffer))
+        {
+            assertRange(reader, 445, 555);
+        }
     }
 
     @Test
@@ -123,32 +126,33 @@ public class BlockBalancedTreeReaderTest extends 
SAIRandomizedTester
         final BlockBalancedTreeRamBuffer buffer = new 
BlockBalancedTreeRamBuffer(Integer.BYTES);
 
         byte[] scratch = new byte[4];
-        for (int docID = 0; docID < numRows; docID++)
+        for (int rowID = 0; rowID < numRows; rowID++)
         {
-            NumericUtils.intToSortableBytes(docID, scratch, 0);
-            buffer.add(docID, scratch);
+            NumericUtils.intToSortableBytes(rowID, scratch, 0);
+            buffer.add(rowID, scratch);
         }
 
-        final BlockBalancedTreeReader reader = finishAndOpenReader(2, buffer);
-
-        PostingList intersection = performIntersection(reader, NONE_MATCH);
-        assertNull(intersection);
-
-        intersection = performIntersection(reader, ALL_MATCH);
-        assertEquals(numRows, intersection.size());
-        assertEquals(100, intersection.advance(100));
-        assertEquals(200, intersection.advance(200));
-        assertEquals(300, intersection.advance(300));
-        assertEquals(400, intersection.advance(400));
-        assertEquals(401, intersection.advance(401));
-        long expectedRowID = 402;
-        for (long id = intersection.nextPosting(); expectedRowID < 500; id = 
intersection.nextPosting())
+        try (BlockBalancedTreeReader reader = finishAndOpenReader(2, buffer))
         {
-            assertEquals(expectedRowID++, id);
-        }
-        assertEquals(PostingList.END_OF_STREAM, intersection.advance(numRows + 
1));
+            PostingList intersection = performIntersection(reader, NONE_MATCH);
+            assertNull(intersection);
+
+            intersection = performIntersection(reader, ALL_MATCH);
+            assertEquals(numRows, intersection.size());
+            assertEquals(100, intersection.advance(100));
+            assertEquals(200, intersection.advance(200));
+            assertEquals(300, intersection.advance(300));
+            assertEquals(400, intersection.advance(400));
+            assertEquals(401, intersection.advance(401));
+            long expectedRowID = 402;
+            for (long id = intersection.nextPosting(); expectedRowID < 500; id 
= intersection.nextPosting())
+            {
+                assertEquals(expectedRowID++, id);
+            }
+            assertEquals(PostingList.END_OF_STREAM, 
intersection.advance(numRows + 1));
 
-        intersection.close();
+            intersection.close();
+        }
     }
 
     @Test
@@ -162,41 +166,42 @@ public class BlockBalancedTreeReaderTest extends 
SAIRandomizedTester
         final BlockBalancedTreeRamBuffer buffer = new 
BlockBalancedTreeRamBuffer(Integer.BYTES);
         byte[] scratch = new byte[4];
 
-        for (int docID = 0; docID < 10; docID++)
+        for (int rowID = 0; rowID < 10; rowID++)
         {
-            NumericUtils.intToSortableBytes(docID, scratch, 0);
-            buffer.add(docID, scratch);
+            NumericUtils.intToSortableBytes(rowID, scratch, 0);
+            buffer.add(rowID, scratch);
         }
 
-        for (int docID = 10; docID < 20; docID++)
+        for (int rowID = 10; rowID < 20; rowID++)
         {
             NumericUtils.intToSortableBytes(10, scratch, 0);
-            buffer.add(docID, scratch);
+            buffer.add(rowID, scratch);
         }
 
-        for (int docID = 20; docID < 30; docID++)
+        for (int rowID = 20; rowID < 30; rowID++)
         {
-            NumericUtils.intToSortableBytes(docID, scratch, 0);
-            buffer.add(docID, scratch);
+            NumericUtils.intToSortableBytes(rowID, scratch, 0);
+            buffer.add(rowID, scratch);
         }
 
-        final BlockBalancedTreeReader reader = finishAndOpenReader(5, buffer);
-
-        PostingList postingList = performIntersection(reader, buildQuery(8, 
15));
-
-        assertEquals(8, postingList.nextPosting());
-        assertEquals(9, postingList.nextPosting());
-        assertEquals(10, postingList.nextPosting());
-        assertEquals(11, postingList.nextPosting());
-        assertEquals(12, postingList.nextPosting());
-        assertEquals(13, postingList.nextPosting());
-        assertEquals(14, postingList.nextPosting());
-        assertEquals(15, postingList.nextPosting());
-        assertEquals(16, postingList.nextPosting());
-        assertEquals(17, postingList.nextPosting());
-        assertEquals(18, postingList.nextPosting());
-        assertEquals(19, postingList.nextPosting());
-        assertEquals(PostingList.END_OF_STREAM, postingList.nextPosting());
+        try (BlockBalancedTreeReader reader = finishAndOpenReader(5, buffer))
+        {
+            PostingList postingList = performIntersection(reader, 
buildQuery(8, 15));
+
+            assertEquals(8, postingList.nextPosting());
+            assertEquals(9, postingList.nextPosting());
+            assertEquals(10, postingList.nextPosting());
+            assertEquals(11, postingList.nextPosting());
+            assertEquals(12, postingList.nextPosting());
+            assertEquals(13, postingList.nextPosting());
+            assertEquals(14, postingList.nextPosting());
+            assertEquals(15, postingList.nextPosting());
+            assertEquals(16, postingList.nextPosting());
+            assertEquals(17, postingList.nextPosting());
+            assertEquals(18, postingList.nextPosting());
+            assertEquals(19, postingList.nextPosting());
+            assertEquals(PostingList.END_OF_STREAM, postingList.nextPosting());
+        }
     }
 
     @Test
@@ -204,22 +209,73 @@ public class BlockBalancedTreeReaderTest extends 
SAIRandomizedTester
     {
         final BlockBalancedTreeRamBuffer buffer = new 
BlockBalancedTreeRamBuffer(Integer.BYTES);
         byte[] scratch = new byte[4];
-        for (int docID = 0; docID < 1000; docID++)
+        for (int rowID = 0; rowID < 1000; rowID++)
         {
-            NumericUtils.intToSortableBytes(docID, scratch, 0);
-            buffer.add(docID, scratch);
+            NumericUtils.intToSortableBytes(rowID, scratch, 0);
+            buffer.add(rowID, scratch);
         }
         // add a gap between 1000 and 1100
-        for (int docID = 1000; docID < 2000; docID++)
+        for (int rowID = 1000; rowID < 2000; rowID++)
+        {
+            NumericUtils.intToSortableBytes(rowID + 100, scratch, 0);
+            buffer.add(rowID, scratch);
+        }
+
+        try (BlockBalancedTreeReader reader = finishAndOpenReader(50, buffer))
+        {
+            final PostingList intersection = performIntersection(reader, 
buildQuery(1017, 1096));
+            assertNull(intersection);
+        }
+    }
+
+    @Test
+    public void testConcurrentIntersectionsOnSameReader() throws Exception
+    {
+        int numRows = 1000;
+
+        final BlockBalancedTreeRamBuffer buffer = new 
BlockBalancedTreeRamBuffer(Integer.BYTES);
+
+        byte[] scratch = new byte[4];
+        for (int rowID = 0; rowID < numRows; rowID++)
         {
-            NumericUtils.intToSortableBytes(docID + 100, scratch, 0);
-            buffer.add(docID, scratch);
+            NumericUtils.intToSortableBytes(rowID, scratch, 0);
+            buffer.add(rowID, scratch);
+        }
+
+        try (BlockBalancedTreeReader reader = finishAndOpenReader(4, buffer))
+        {
+            int concurrency = 100;
+
+            ExecutorService executor = 
Executors.newFixedThreadPool(concurrency);
+            List<Future<?>> results = new ArrayList<>();
+            for (int thread = 0; thread < concurrency; thread++)
+            {
+                results.add(executor.submit(() -> assertRange(reader, 445, 
555)));
+            }
+            FBUtilities.waitOnFutures(results);
+            executor.shutdown();
         }
+    }
 
-        final BlockBalancedTreeReader reader = finishAndOpenReader(50, buffer);
+    @SuppressWarnings("SameParameterValue")
+    private void assertRange(BlockBalancedTreeReader reader, long lowerBound, 
long upperBound)
+    {
+        Expression expression = new Expression(indexContext);
+        expression.add(Operator.GT, Int32Type.instance.decompose(444));
+        expression.add(Operator.LT, Int32Type.instance.decompose(555));
 
-        final PostingList intersection = performIntersection(reader, 
buildQuery(1017, 1096));
-        assertNull(intersection);
+        try
+        {
+            PostingList intersection = performIntersection(reader, 
BlockBalancedTreeQueries.balancedTreeQueryFrom(expression, 4));
+            assertNotNull(intersection);
+            assertEquals(upperBound - lowerBound, intersection.size());
+            for (long posting = lowerBound; posting < upperBound; posting++)
+                assertEquals(posting, intersection.nextPosting());
+        }
+        catch (IOException e)
+        {
+            throw Throwables.unchecked(e);
+        }
     }
 
     private PostingList performIntersection(BlockBalancedTreeReader reader, 
BlockBalancedTreeReader.IntersectVisitor visitor)
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeTest.java
index 710da74b91..d749b13e77 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeTest.java
@@ -53,26 +53,33 @@ public class BlockBalancedTreeTest extends 
SAIRandomizedTester
     @Test
     public void testSingleLeaf() throws Exception
     {
-        BlockBalancedTreeWalker.TraversalState state = 
generateBalancedTree(100, 100, rowID -> rowID);
+        try (BlockBalancedTreeWalker walker = generateBalancedTree(100, 100, 
rowID -> rowID))
+        {
+            assertEquals(1, walker.numLeaves);
+            assertEquals(1, walker.treeDepth);
+            assertEquals(100, walker.valueCount);
+
+            BlockBalancedTreeWalker.TraversalState state = 
walker.newTraversalState();
 
-        assertEquals(1, state.numLeaves);
-        assertEquals(1, state.treeDepth);
-        assertEquals(100, state.valueCount);
-        assertTrue(state.atLeafNode());
+            assertTrue(state.atLeafNode());
 
-        recursiveAssertTraversal(state, -1);
+            recursiveAssertTraversal(state, -1);
 
-        assertEquals(state.treeDepth, state.maxLevel + 1);
+            assertEquals(walker.treeDepth, state.maxLevel + 1);
+        }
     }
 
     @Test
     public void testTreeWithSameValue() throws Exception
     {
-        BlockBalancedTreeWalker.TraversalState state = 
generateBalancedTree(100, 4, rowID -> 1);
+        try (BlockBalancedTreeWalker walker = generateBalancedTree(100, 4, 
rowID -> 1))
+        {
+            BlockBalancedTreeWalker.TraversalState state = 
walker.newTraversalState();
 
-        recursiveAssertTraversal(state, -1);
+            recursiveAssertTraversal(state, -1);
 
-        assertEquals(state.treeDepth, state.maxLevel + 1);
+            assertEquals(walker.treeDepth, state.maxLevel + 1);
+        }
     }
 
     @Test
@@ -83,14 +90,17 @@ public class BlockBalancedTreeTest extends 
SAIRandomizedTester
         {
             int numRows = leafSize * numLeaves;
 
-            BlockBalancedTreeWalker.TraversalState state = 
generateBalancedTree(numRows, leafSize, rowID -> rowID);
+            try (BlockBalancedTreeWalker walker = 
generateBalancedTree(numRows, leafSize, rowID -> rowID))
+            {
+                assertEquals(numLeaves, walker.numLeaves);
+                assertTrue(walker.treeDepth <= walker.numLeaves);
 
-            assertEquals(numLeaves, state.numLeaves);
-            assertTrue(state.treeDepth <= state.numLeaves);
+                BlockBalancedTreeWalker.TraversalState state = 
walker.newTraversalState();
 
-            recursiveAssertTraversal(state, -1);
+                recursiveAssertTraversal(state, -1);
 
-            assertEquals(state.treeDepth, state.maxLevel + 1);
+                assertEquals(walker.treeDepth, state.maxLevel + 1);
+            }
         }
     }
 
@@ -104,11 +114,14 @@ public class BlockBalancedTreeTest extends 
SAIRandomizedTester
             int leafSize = nextInt(2, 512);
             int numRows = nextInt(1000, 10000);
 
-            BlockBalancedTreeWalker.TraversalState state = 
generateBalancedTree(numRows, leafSize, rowID -> nextInt(0, numRows / 2));
+            try (BlockBalancedTreeWalker walker = 
generateBalancedTree(numRows, leafSize, rowID -> nextInt(0, numRows / 2)))
+            {
+                BlockBalancedTreeWalker.TraversalState state = 
walker.newTraversalState();
 
-            recursiveAssertTraversal(state, -1);
+                recursiveAssertTraversal(state, -1);
 
-            assertEquals(state.treeDepth, state.maxLevel + 1);
+                assertEquals(walker.treeDepth, state.maxLevel + 1);
+            }
         }
     }
 
@@ -134,14 +147,13 @@ public class BlockBalancedTreeTest extends 
SAIRandomizedTester
         }
     }
 
-    private BlockBalancedTreeWalker.TraversalState generateBalancedTree(int 
numRows, int leafSize, IntFunction<Integer> valueProvider) throws Exception
+    private BlockBalancedTreeWalker generateBalancedTree(int numRows, int 
leafSize, IntFunction<Integer> valueProvider) throws Exception
     {
         long treeOffset = writeBalancedTree(numRows, leafSize, valueProvider);
 
         DataInput input = dataOutput.toDataInput();
 
-        input.skipBytes(treeOffset);
-        return new BlockBalancedTreeWalker.TraversalState(input);
+        return new BlockBalancedTreeWalker(input, treeOffset);
     }
 
     private long writeBalancedTree(int numRows, int leafSize, 
IntFunction<Integer> valueProvider) throws Exception


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to