This is an automated email from the ASF dual-hosted git repository.

belliottsmith pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-accord.git

commit b81408ccbd84f90278b6327402db08831ed61bcf
Author: Benedict Elliott Smith <[email protected]>
AuthorDate: Sat May 16 16:52:50 2026 +0100

    Slice PreLoadContext to owned keys in MapReduceCommandStores
    
    patch by Benedict; reviewed by Ariel Weisberg for CASSANDRA-21836
---
 .../java/accord/impl/progresslog/HomeState.java    |   3 +-
 .../src/main/java/accord/local/CommandStores.java  | 202 ++++++++++++++-------
 .../java/accord/local/MapReduceCommandStores.java  |  11 +-
 .../local/MapReduceConsumeCommandStores.java       |   3 +-
 accord-core/src/main/java/accord/local/Node.java   |   3 +-
 .../src/main/java/accord/local/PreLoadContext.java |  47 +++++
 .../accord/primitives/AbstractUnseekableKeys.java  |   7 +-
 7 files changed, 202 insertions(+), 74 deletions(-)

diff --git a/accord-core/src/main/java/accord/impl/progresslog/HomeState.java b/accord-core/src/main/java/accord/impl/progresslog/HomeState.java
index ad781c77..60f2dc7f 100644
--- a/accord-core/src/main/java/accord/impl/progresslog/HomeState.java
+++ b/accord-core/src/main/java/accord/impl/progresslog/HomeState.java
@@ -27,6 +27,7 @@ import accord.coordinate.Outcome;
 import accord.local.Command;
 import accord.local.CommandStores;
 import accord.local.CommandStores.IncludingSpecificStoreSelector;
+import accord.local.CommandStores.LatentStoreSelector;
 import accord.local.SafeCommand;
 import accord.local.SafeCommandStore;
 import accord.primitives.ProgressToken;
@@ -219,7 +220,7 @@ abstract class HomeState extends BaseTxnState
 
         ProgressToken maxProgressToken = owner.savedProgressToken(txnId).merge(command);
         CallbackInvoker<ProgressToken, Outcome> invoker = invokeHomeCallback(owner, txnId, maxProgressToken, HomeState::recoverCallback);
-        CommandStores.StoreSelector reportTo = new IncludingSpecificStoreSelector(safeStore.commandStore().id());
+        LatentStoreSelector reportTo = new IncludingSpecificStoreSelector(safeStore.commandStore().id());
 
         if (tracing != null)
             tracing.trace(safeStore.commandStore(), "Invoking MaybeRecover with progress token %s", maxProgressToken);
diff --git a/accord-core/src/main/java/accord/local/CommandStores.java b/accord-core/src/main/java/accord/local/CommandStores.java
index b1eca072..69c05478 100644
--- a/accord-core/src/main/java/accord/local/CommandStores.java
+++ b/accord-core/src/main/java/accord/local/CommandStores.java
@@ -20,7 +20,6 @@ package accord.local;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
@@ -29,11 +28,9 @@ import java.util.Objects;
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
-import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-import java.util.stream.Stream;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
@@ -45,6 +42,7 @@ import org.slf4j.LoggerFactory;
 import accord.api.Agent;
 import accord.api.AsyncExecutorFactory;
 import accord.api.AsyncExecutor;
+import accord.api.VisibleForImplementation;
 import accord.topology.EpochReady;
 import accord.api.DataStore;
 import accord.api.Journal;
@@ -109,9 +107,7 @@ public abstract class CommandStores implements AsyncExecutorFactory
             @Override
             public StoreSelector refine(TxnId txnId, @Nullable Timestamp executeAt, Participants<?> participants)
             {
-                return snapshot -> StoreFinder.find(snapshot, participants)
-                                              .filter(snapshot, participants, txnId.epoch(), (executeAt != null ? executeAt : txnId).epoch())
-                                              .iterator(snapshot);
+                return new RestrictedStoreSelector(participants, txnId.epoch(), (executeAt != null ? executeAt : txnId).epoch());
             }
         }
 
@@ -124,10 +120,47 @@ public abstract class CommandStores implements AsyncExecutorFactory
     public interface StoreSelector extends LatentStoreSelector
     {
         default StoreSelector refine(TxnId txnId, @Nullable Timestamp executeAt, Participants<?> participants) { return this; }
-        Iterator<CommandStore> select(Snapshot snapshot);
+        StoreSelection select(Snapshot snapshot);
+
+        /**
+         * Return the ranges we use to interact with this shard, restricted by epoch;
+         * return null if we do not intersect this shard on the specified epochs
+         */
+        Ranges ranges(ShardHolder shard);
+    }
+
+    public interface UnrestrictedStoreSelector extends StoreSelector
+    {
+        default Ranges ranges(ShardHolder shard) { return shard.ranges().all(); }
+    }
+
+    public static class RestrictedStoreSelector implements StoreSelector
+    {
+        final Unseekables<?> keysOrRanges;
+        final long minEpoch, maxEpoch;
+
+        public RestrictedStoreSelector(Unseekables<?> keysOrRanges, long minEpoch, long maxEpoch)
+        {
+            this.keysOrRanges = keysOrRanges;
+            this.minEpoch = minEpoch;
+            this.maxEpoch = maxEpoch;
+        }
+
+        public StoreSelection select(Snapshot snapshot)
+        {
+            return StoreFinder.find(snapshot, keysOrRanges);
+        }
+
+        public @Nullable Ranges ranges(ShardHolder shard)
+        {
+            Ranges ranges = shard.ranges().allBetween(minEpoch, maxEpoch);
+            if (ranges != shard.ranges.all() && !ranges.intersects(keysOrRanges))
+                return null;
+            return ranges;
+        }
     }
 
-    public static class IncludingSpecificStoreSelector implements StoreSelector
+    public static class IncludingSpecificStoreSelector implements LatentStoreSelector
     {
         final int storeId;
 
@@ -139,44 +172,92 @@ public abstract class CommandStores implements AsyncExecutorFactory
         @Override
         public StoreSelector refine(TxnId txnId, @Nullable Timestamp executeAt, Participants<?> participants)
         {
-            return snapshot -> {
-                StoreFinder finder = StoreFinder.find(snapshot, participants)
-                                                .filter(snapshot, participants, txnId.epoch(), (executeAt != null ? executeAt : txnId).epoch());
-                finder.set(snapshot.byId.get(storeId));
-                return finder.iterator(snapshot);
-            };
-        }
+            return new RestrictedStoreSelector(participants, txnId.epoch(), (executeAt != null ? executeAt : txnId).epoch())
+            {
+                @Override
+                public StoreSelection select(Snapshot snapshot)
+                {
+                    StoreSelection selection = super.select(snapshot);
+                    int index = snapshot.byId.get(storeId);
+                    if (index >= 0)
+                        selection.set(index);
+                    return selection;
+                }
 
-        @Override
-        public Iterator<CommandStore> select(Snapshot snapshot)
-        {
-            return Collections.singletonList(snapshot.byId(storeId)).iterator();
+                @Nullable
+                @Override
+                public Ranges ranges(ShardHolder shard)
+                {
+                    // TODO (expected): accept minEpoch/maxEpoch for the store so we can safely slice here
+                    if (shard.store.id == storeId)
+                        return shard.ranges().all();
+                    return super.ranges(shard);
+                }
+            };
         }
     }
 
     // TODO (required): as we get more tables this will become expensive to allocate; we need to index first by prefix
-    public static class StoreFinder extends LargeBitSet implements IndexedQuadConsumer<Object, Object, Object, Object>, IndexedRangeQuadConsumer<Object, Object, Object, Object>
+    public static class StoreSelection extends LargeBitSet
     {
-        final int[] indexMap;
+        public StoreSelection(Snapshot snapshot)
+        {
+            super(snapshot.shards.length);
+        }
+
+        public static StoreSelection ids(Snapshot snapshot, IntStream ids)
+        {
+            StoreSelection selection = new StoreFinder(snapshot);
+            ids.forEach(id -> {
+                int index = snapshot.byId.get(id);
+                if (index >= 0)
+                    selection.set(index);
+            });
+            return selection;
+        }
 
-        private StoreFinder(int size, int[] indexMap)
+        public static StoreSelection allOf(Snapshot snapshot)
         {
-            super(size);
-            this.indexMap = indexMap;
+            StoreSelection selection = new StoreFinder(snapshot);
+            selection.setRange(0, snapshot.shards.length);
+            return selection;
+        }
+
+        public final Iterator<ShardHolder> iterator(Snapshot snapshot)
+        {
+            return new Iterator<>()
+            {
+                int i = firstSetBit();
+                @Override
+                public boolean hasNext()
+                {
+                    return i >= 0;
+                }
+
+                @Override
+                public ShardHolder next()
+                {
+                    ShardHolder next = snapshot.shards[i];
+                    i = nextSetBit(i + 1, -1);
+                    return next;
+                }
+            };
         }
+    }
+
+    public static class StoreFinder extends StoreSelection implements IndexedQuadConsumer<Object, Object, Object, Object>, IndexedRangeQuadConsumer<Object, Object, Object, Object>
+    {
+        final int[] indexMap;
 
         public StoreFinder(Snapshot snapshot)
         {
-            this(snapshot.shards.length, snapshot.indexForRange);
+            super(snapshot);
+            this.indexMap = snapshot.indexForRange;
         }
 
         public static StoreSelector selector(Unseekables<?> unseekables, long minEpoch, long maxEpoch)
         {
-            return snapshot -> {
-                StoreFinder finder = StoreFinder.find(snapshot, unseekables);
-                finder.filter(snapshot, unseekables, minEpoch, maxEpoch);
-                return finder.iterator(snapshot);
-            };
+            return new RestrictedStoreSelector(unseekables, minEpoch, maxEpoch);
         }
 
         public static StoreFinder find(Snapshot snapshot, Unseekables<?> unseekables)
@@ -215,27 +296,6 @@ public abstract class CommandStores implements AsyncExecutorFactory
             return this;
         }
 
-        public Iterator<CommandStore> iterator(Snapshot snapshot)
-        {
-            return new Iterator<>()
-            {
-                int i = firstSetBit();
-                @Override
-                public boolean hasNext()
-                {
-                    return i >= 0;
-                }
-
-                @Override
-                public CommandStore next()
-                {
-                    CommandStore next = snapshot.shards[i].store;
-                    i = nextSetBit(i + 1, -1);
-                    return next;
-                }
-            };
-        }
-
         @Override
         public void accept(Object p1, Object p2, Object p3, Object p4, int index)
         {
@@ -317,12 +377,6 @@ public abstract class CommandStores implements AsyncExecutorFactory
             return ranges;
         }
 
-        boolean filter(long minEpoch, long maxEpoch, Unseekables<?> unseekables)
-        {
-            Ranges shardRanges = ranges.allBetween(minEpoch, maxEpoch);
-            return shardRanges != ranges.all() && !shardRanges.intersects(unseekables);
-        }
-
         public String toString()
         {
             return store.id() + " " + ranges;
@@ -647,6 +701,11 @@ public abstract class CommandStores implements AsyncExecutorFactory
             return shards[byId.get(id)].store;
         }
 
+        ShardHolder shardById(int id)
+        {
+            return shards[byId.get(id)];
+        }
+
         @Override
         public Iterator<ShardHolder> iterator()
         {
@@ -854,7 +913,7 @@ public abstract class CommandStores implements AsyncExecutorFactory
 
     public AsyncChain<Void> forAll(String reason, Consumer<SafeCommandStore> forEach)
     {
-        return mapReduce(snapshot -> Stream.of(snapshot.shards).map(shard -> shard.store).iterator(), new MapReduceCommandStores<>(RoutingKeys.EMPTY)
+        return mapReduce((UnrestrictedStoreSelector) StoreSelection::allOf, new MapReduceCommandStores<>(RoutingKeys.EMPTY)
         {
             @Override public Void reduce(Void o1, Void o2) { return null; }
             @Override public TxnId primaryTxnId() { return null; }
@@ -923,30 +982,41 @@ public abstract class CommandStores implements AsyncExecutorFactory
 
     public <O> AsyncChain<O> mapReduce(IntStream commandStoreIds, MapReduceCommandStores<?, O> mapReduce)
     {
-        return mapReduce(snapshot -> commandStoreIds.mapToObj(snapshot::byId).iterator(), mapReduce);
+        return mapReduce((UnrestrictedStoreSelector) snapshot -> StoreSelection.ids(snapshot, commandStoreIds), mapReduce);
     }
 
     public <O> AsyncChain<O> mapReduce(StoreSelector selector, MapReduceCommandStores<?, O> mapReduceConsume)
     {
         Snapshot snapshot = current;
-        Iterator<CommandStore> stores = selector.select(snapshot);
+        StoreSelection selection = selector.select(snapshot);
         AsyncChain<O> chain = null;
-        while (stores.hasNext())
+        for (int i = selection.firstSetBit(); i >= 0 ; i = selection.nextSetBit(i + 1, -1))
         {
-            AsyncChain<O> next = mapReduceConsume.applyAsync(stores.next());
+            ShardHolder shard = snapshot.shards[i];
+            Ranges ranges = selector.ranges(shard);
+            if (ranges == null)
+                continue;
+
+            AsyncChain<O> next = mapReduceConsume.applyAsync(ranges, shard.store);
             if (next != null)
                 chain = chain != null ? AsyncChains.reduce(chain, next, mapReduceConsume) : next;
         }
         return chain == null ? AsyncChains.success(null) : chain;
     }
 
-    public <O> O mapReduceUnsafe(StoreSelector selector, Function<CommandStore, O> map, BiFunction<O, O, O> reduce, O accumulator)
+    @VisibleForImplementation
+    public <O> O mapReduceUnsafe(StoreSelector selector, BiFunction<Ranges, CommandStore, O> map, BiFunction<O, O, O> reduce, O accumulator)
     {
         Snapshot snapshot = current;
-        Iterator<CommandStore> stores = selector.select(snapshot);
-        while (stores.hasNext())
+        StoreSelection selection = selector.select(snapshot);
+        for (int i = selection.firstSetBit(); i >= 0 ; i = selection.nextSetBit(i + 1, -1))
         {
-            O next = map.apply(stores.next());
+            ShardHolder shard = snapshot.shards[i];
+            Ranges ranges = selector.ranges(shard);
+            if (ranges == null)
+                continue;
+
+            O next = map.apply(ranges, shard.store);
             accumulator = reduce.apply(accumulator, next);
         }
         return accumulator;
diff --git a/accord-core/src/main/java/accord/local/MapReduceCommandStores.java b/accord-core/src/main/java/accord/local/MapReduceCommandStores.java
index 96eb1cf5..b9534764 100644
--- a/accord-core/src/main/java/accord/local/MapReduceCommandStores.java
+++ b/accord-core/src/main/java/accord/local/MapReduceCommandStores.java
@@ -22,10 +22,13 @@ import javax.annotation.Nullable;
 
 import accord.api.Tracing;
 import accord.primitives.Participants;
+import accord.primitives.Ranges;
 import accord.primitives.Unseekables;
 import accord.utils.MapReduce;
 import accord.utils.async.AsyncChain;
 
+import static accord.primitives.Routables.Slice.Minimal;
+
 public abstract class MapReduceCommandStores<P extends Participants<?>, O> implements PreLoadContext, MapReduce<SafeCommandStore, O>
 {
     public final P scope;
@@ -50,14 +53,14 @@ public abstract class MapReduceCommandStores<P extends Participants<?>, O> imple
         return applyInternal(safeStore);
     }
 
-    public final AsyncChain<O> applyAsync(CommandStore commandStore)
+    public final AsyncChain<O> applyAsync(Ranges ranges, CommandStore commandStore)
     {
-        return applyAsyncInternal(commandStore);
+        return applyAsyncInternal(ranges, commandStore);
     }
 
-    protected AsyncChain<O> applyAsyncInternal(CommandStore commandStore)
+    protected AsyncChain<O> applyAsyncInternal(Ranges ranges, CommandStore commandStore)
     {
-        return commandStore.chain(this, this);
+        return commandStore.chain(slice(ranges, Minimal), this);
     }
 
     protected boolean supportsPartialRefusal()
diff --git a/accord-core/src/main/java/accord/local/MapReduceConsumeCommandStores.java b/accord-core/src/main/java/accord/local/MapReduceConsumeCommandStores.java
index 31190ce8..e5f7b9ee 100644
--- a/accord-core/src/main/java/accord/local/MapReduceConsumeCommandStores.java
+++ b/accord-core/src/main/java/accord/local/MapReduceConsumeCommandStores.java
@@ -22,6 +22,7 @@ import java.util.function.Function;
 import javax.annotation.Nullable;
 
 import accord.primitives.Participants;
+import accord.primitives.Ranges;
 import accord.primitives.TxnId;
 import accord.utils.MapReduceConsume;
 import accord.utils.async.AsyncChain;
@@ -38,7 +39,7 @@ public abstract class MapReduceConsumeCommandStores<P extends Participants<?>, O
         return new Delegate<>(this)
         {
             @Override
-            protected AsyncChain<O> applyAsyncInternal(CommandStore commandStore)
+            protected AsyncChain<O> applyAsyncInternal(Ranges ranges, CommandStore commandStore)
             {
                 return apply.apply(commandStore);
             }
diff --git a/accord-core/src/main/java/accord/local/Node.java b/accord-core/src/main/java/accord/local/Node.java
index 00022a60..785fe9c7 100644
--- a/accord-core/src/main/java/accord/local/Node.java
+++ b/accord-core/src/main/java/accord/local/Node.java
@@ -40,6 +40,7 @@ import accord.api.TopologyService;
 import accord.api.Tracing;
 import accord.coordinate.ExecuteTxn;
 import accord.impl.LocalDelivery;
+import accord.local.CommandStores.UnrestrictedStoreSelector;
 import accord.local.cfk.ExecuteTxnBacklog;
 import accord.messages.RemoteSuccess;
 import accord.messages.ReplyContext.NoReplyContext;
@@ -827,7 +828,7 @@ public class Node implements NodeCommandStoreService
     public AsyncChain<Void> updateMinHlc(long minHlc)
     {
         // TODO (required): command stores that are not ready due to bootstrap need to refresh their min HLC on bootstrap completion
-        StoreSelector selector = snapshot -> Stream.of(snapshot.shards).map(sh -> sh.store).iterator();
+        UnrestrictedStoreSelector selector = CommandStores.StoreSelection::allOf;
         return commandStores().mapReduce(selector, new MapReduceCommandStores<>(RoutingKeys.EMPTY)
         {
             @Override public Void reduce(Void o1, Void o2) { return null; }
diff --git a/accord-core/src/main/java/accord/local/PreLoadContext.java b/accord-core/src/main/java/accord/local/PreLoadContext.java
index 68de4a04..9d2d4431 100644
--- a/accord-core/src/main/java/accord/local/PreLoadContext.java
+++ b/accord-core/src/main/java/accord/local/PreLoadContext.java
@@ -19,9 +19,13 @@
 package accord.local;
 
 import accord.api.RoutingKey;
+import accord.api.VisibleForImplementation;
 import accord.local.cfk.CommandsForKey;
 import accord.primitives.AbstractUnseekableKeys;
+import accord.primitives.Ranges;
 import accord.primitives.Routable;
+import accord.primitives.Routables;
+import accord.primitives.Routables.Slice;
 import accord.primitives.RoutingKeys;
 import accord.primitives.Timestamp;
 import accord.primitives.TxnId;
@@ -92,6 +96,20 @@ public interface PreLoadContext
             consumer.accept(additionalTxnId);
     }
 
+    default PreLoadContext slice(Ranges ranges, Slice slice)
+    {
+        Unseekables<?> keys = keys();
+        int size = keys.size();
+        if (size == 0 || loadKeys() == NONE)
+            return this;
+
+        Unseekables<?> newKeys = keys.slice(ranges, slice);
+        if (newKeys == keys)
+            return this;
+
+        return new OverrideKeys(this, newKeys);
+    }
+
     /**
      * @return keys of the {@link CommandsForKey} objects that need to be loaded into memory before this operation is run
      */
@@ -159,6 +177,35 @@ public interface PreLoadContext
         return reason() + (txnIds.isEmpty() ? "" : " for " + txnIds) + (keys.isEmpty() ? "" : (txnIds.isEmpty() ? " for " : " and ") + keys());
     }
 
+    class Wrapped implements PreLoadContext
+    {
+        final PreLoadContext wrapped;
+        public Wrapped(PreLoadContext wrapped)
+        {
+            this.wrapped = wrapped;
+        }
+        @Nullable @Override public TxnId primaryTxnId() { return wrapped.primaryTxnId(); }
+        @Nullable @Override public TxnId additionalTxnId() { return wrapped.additionalTxnId(); }
+        @Override public Unseekables<?> keys() { return wrapped.keys(); }
+        @Override public LoadKeys loadKeys() { return wrapped.loadKeys(); }
+        @Override public LoadKeysFor loadKeysFor() { return wrapped.loadKeysFor(); }
+        @Override public Timestamp executeAt() { return wrapped.executeAt(); }
+        @Override public String reason() { return wrapped.reason(); }
+        @Override public String describe() { return wrapped.describe(); }
+    }
+
+    class OverrideKeys extends Wrapped
+    {
+        final Unseekables<?> keys;
+        public OverrideKeys(PreLoadContext wrapped, Unseekables<?> keys)
+        {
+            super(wrapped);
+            this.keys = keys;
+        }
+
+        @Override public Unseekables<?> keys() { return keys; }
+    }
+
     static PreLoadContext contextFor(@Nullable TxnId primary, @Nullable TxnId additional, Unseekables<?> keys, LoadKeys loadKeys, LoadKeysFor loadKeysFor, String reason)
     {
         Invariants.require(primary == null ? additional == null : !primary.equals(additional));
diff --git a/accord-core/src/main/java/accord/primitives/AbstractUnseekableKeys.java b/accord-core/src/main/java/accord/primitives/AbstractUnseekableKeys.java
index f569d765..ece1a336 100644
--- a/accord-core/src/main/java/accord/primitives/AbstractUnseekableKeys.java
+++ b/accord-core/src/main/java/accord/primitives/AbstractUnseekableKeys.java
@@ -57,7 +57,12 @@ implements Iterable<RoutingKey>, Unseekables<RoutingKey>, Participants<RoutingKe
     @Override
     public final int find(RoutingKey key, SortedArrays.Search search)
     {
-        return SortedArrays.binarySearch(keys, 0, keys.length, key, RoutingKey::compareTo, search);
+        return find(key, 0, keys.length, search);
+    }
+
+    public final int find(RoutingKey key, int start, int end, SortedArrays.Search search)
+    {
+        return SortedArrays.binarySearch(keys, start, end, key, RoutingKey::compareTo, search);
     }
 
     @Override


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to