This is an automated email from the ASF dual-hosted git repository.

dcapwell pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 61b3d5a54b Add support for BEGIN TRANSACTION to allow mutations that touch multiple partitions
61b3d5a54b is described below

commit 61b3d5a54befe440044ad86159a62fee487229eb
Author: David Capwell <dcapw...@apache.org>
AuthorDate: Mon Aug 25 14:08:15 2025 -0700

    Add support for BEGIN TRANSACTION to allow mutations that touch multiple partitions
    
    patch by David Capwell; reviewed by Ariel Weisberg for CASSANDRA-20857
---
 CHANGES.txt                                        |   1 +
 .../cassandra/cql3/statements/CQL3CasRequest.java  |   6 +-
 .../cql3/statements/ModificationStatement.java     |  42 ++++---
 .../cql3/statements/TransactionStatement.java      |  50 +++++++-
 .../serializers/AbstractSortedCollector.java       |   8 ++
 .../service/consensus/TransactionalMode.java       |   8 ++
 .../distributed/test/accord/AccordCQLTestBase.java |  63 ++++++++++
 .../cql3/AccordInteropMultiNodeTableWalkBase.java  |  15 +++
 .../test/cql3/SingleNodeTableWalkTest.java         |  41 ++++---
 .../distributed/test/cql3/StatefulASTBase.java     | 114 +++++-------------
 .../cassandra/distributed/util/DriverUtils.java    | 129 +++++++++++++++++++++
 .../fuzz/topology/AccordTopologyMixupTest.java     |   2 +-
 .../cassandra/harry/model/ASTSingleTableModel.java |   4 +-
 .../harry/model/ASTSingleTableModelTest.java       |  35 ++++++
 .../org/apache/cassandra/utils/ASTGenerators.java  |   5 +-
 15 files changed, 391 insertions(+), 132 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index b61cd947e0..08ae2c6162 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 5.1
+ * Add support for BEGIN TRANSACTION to allow mutations that touch multiple partitions (CASSANDRA-20857)
  * AutoRepair: Safeguard Full repair against disk protection(CASSANDRA-20045)
  * BEGIN TRANSACTION crashes if a mutation touches multiple rows (CASSANDRA-20844)
  * Fix version range check in MessagingService.getVersionOrdinal (CASSANDRA-20842)
diff --git a/src/java/org/apache/cassandra/cql3/statements/CQL3CasRequest.java 
b/src/java/org/apache/cassandra/cql3/statements/CQL3CasRequest.java
index cc1680b19a..080a705ffa 100644
--- a/src/java/org/apache/cassandra/cql3/statements/CQL3CasRequest.java
+++ b/src/java/org/apache/cassandra/cql3/statements/CQL3CasRequest.java
@@ -558,15 +558,13 @@ public class CQL3CasRequest implements CASRequest
             // see CASSANDRA-18337
             ModificationStatement modification = update.stmt.forTxn();
             QueryOptions options = update.options;
-            TxnWrite.Fragment fragment = 
modification.getTxnWriteFragment(idx++, state, options, partitionKey);
-            fragments.add(fragment);
+            fragments.addAll(modification.getTxnWriteFragment(idx++, state, 
options, partitionKey));
         }
         for (RangeDeletion rangeDeletion : rangeDeletions)
         {
             ModificationStatement modification = rangeDeletion.stmt;
             QueryOptions options = rangeDeletion.options;
-            TxnWrite.Fragment fragment = 
modification.getTxnWriteFragment(idx++, state, options, partitionKey);
-            fragments.add(fragment);
+            fragments.addAll(modification.getTxnWriteFragment(idx++, state, 
options, partitionKey));
         }
         return fragments;
     }
diff --git 
a/src/java/org/apache/cassandra/cql3/statements/ModificationStatement.java 
b/src/java/org/apache/cassandra/cql3/statements/ModificationStatement.java
index fa7aaaebaa..4544926fb9 100644
--- a/src/java/org/apache/cassandra/cql3/statements/ModificationStatement.java
+++ b/src/java/org/apache/cassandra/cql3/statements/ModificationStatement.java
@@ -39,6 +39,7 @@ import com.google.common.collect.Lists;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import accord.utils.Invariants;
 import org.apache.cassandra.auth.Permission;
 import org.apache.cassandra.cql3.Attributes;
 import org.apache.cassandra.cql3.CQLStatement;
@@ -886,15 +887,15 @@ public abstract class ModificationStatement implements 
CQLStatement.SingleKeyspa
         }
     }
 
-    public PartitionUpdate getTxnUpdate(ClientState state, QueryOptions 
options)
+    public List<PartitionUpdate> getTxnUpdate(ClientState state, QueryOptions 
options)
     {
         List<? extends IMutation> mutations = getMutations(state, options, 
false, 0, 0, new Dispatcher.RequestTime(0, 0));
-        // TODO: Temporary fix for CASSANDRA-20079
         if (mutations.isEmpty())
-            return PartitionUpdate.emptyUpdate(metadata, 
metadata.partitioner.decorateKey(ByteBufferUtil.EMPTY_BYTE_BUFFER));
-        if (mutations.size() != 1)
-            throw new IllegalArgumentException("When running withing a 
transaction, modification statements may only mutate a single partition");
-        return 
Iterables.getOnlyElement(mutations.get(0).getPartitionUpdates());
+            return Collections.emptyList();
+        List<PartitionUpdate> updates = new ArrayList<>(mutations.size());
+        for (IMutation m : mutations)
+            updates.addAll(m.getPartitionUpdates());
+        return updates;
     }
 
     private static List<TxnReferenceOperation> 
getTxnReferenceOps(List<ReferenceOperation> operations, QueryOptions options)
@@ -948,20 +949,33 @@ public abstract class ModificationStatement implements 
CQLStatement.SingleKeyspa
         return operations.allSubstitutions();
     }
 
-    public TxnWrite.Fragment getTxnWriteFragment(int index, ClientState state, 
QueryOptions options, PartitionKey partitionKey)
+    public List<TxnWrite.Fragment> getTxnWriteFragment(int index, ClientState 
state, QueryOptions options, PartitionKey partitionKey)
     {
-        PartitionUpdate baseUpdate = getTxnUpdate(state, options);
-        TxnReferenceOperations referenceOps = getTxnReferenceOps(options, 
state);
-        long timestamp = attrs.isTimestampSet() ? 
attrs.getTimestamp(TxnWrite.NO_TIMESTAMP, options) : TxnWrite.NO_TIMESTAMP;
-        return new TxnWrite.Fragment(partitionKey, index, baseUpdate, 
referenceOps, timestamp);
+        return getTxnWriteFragment(index, state, options, baseUpdate -> {
+            
Invariants.require(baseUpdate.partitionKey().equals(partitionKey.partitionKey()),
 "PartitionUpdate generated a partition key different than the one expected");
+            return partitionKey;
+        });
+    }
+
+    public List<TxnWrite.Fragment> getTxnWriteFragment(int index, ClientState 
state, QueryOptions options, KeyCollector keyCollector)
+    {
+        return getTxnWriteFragment(index, state, options, baseUpdate -> 
keyCollector.collect(baseUpdate.metadata(), baseUpdate.partitionKey()));
     }
 
-    public TxnWrite.Fragment getTxnWriteFragment(int index, ClientState state, 
QueryOptions options, KeyCollector keyCollector)
+    private List<TxnWrite.Fragment> getTxnWriteFragment(int index, ClientState 
state, QueryOptions options, java.util.function.Function<PartitionUpdate, 
PartitionKey> keyCollector)
     {
-        PartitionUpdate baseUpdate = getTxnUpdate(state, options);
+        List<PartitionUpdate> baseUpdates = getTxnUpdate(state, options);
         TxnReferenceOperations referenceOps = getTxnReferenceOps(options, 
state);
         long timestamp = attrs.isTimestampSet() ? 
attrs.getTimestamp(TxnWrite.NO_TIMESTAMP, options) : TxnWrite.NO_TIMESTAMP;
-        return new 
TxnWrite.Fragment(keyCollector.collect(baseUpdate.metadata(), 
baseUpdate.partitionKey()), index, baseUpdate, referenceOps, timestamp);
+        if (baseUpdates.size() == 1)
+        {
+            PartitionUpdate baseUpdate = baseUpdates.get(0);
+            return Collections.singletonList(new 
TxnWrite.Fragment(keyCollector.apply(baseUpdate), index, baseUpdate, 
referenceOps, timestamp));
+        }
+        List<TxnWrite.Fragment> fragments = new 
ArrayList<>(baseUpdates.size());
+        for (PartitionUpdate baseUpdate : baseUpdates)
+            fragments.add(new 
TxnWrite.Fragment(keyCollector.apply(baseUpdate), index, baseUpdate, 
referenceOps, timestamp));
+        return fragments;
     }
 
     final void addUpdates(UpdatesCollector collector,
diff --git 
a/src/java/org/apache/cassandra/cql3/statements/TransactionStatement.java 
b/src/java/org/apache/cassandra/cql3/statements/TransactionStatement.java
index 6ca18b9f02..2a96bcc16d 100644
--- a/src/java/org/apache/cassandra/cql3/statements/TransactionStatement.java
+++ b/src/java/org/apache/cassandra/cql3/statements/TransactionStatement.java
@@ -27,6 +27,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.SortedSet;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import javax.annotation.Nullable;
@@ -36,6 +37,8 @@ import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 
+import org.slf4j.LoggerFactory;
+
 import accord.api.Key;
 import accord.primitives.Keys;
 import accord.primitives.Routable.Domain;
@@ -90,6 +93,7 @@ import org.apache.cassandra.tcm.Epoch;
 import org.apache.cassandra.transport.Dispatcher;
 import org.apache.cassandra.transport.messages.ResultMessage;
 import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.NoSpamLogger;
 
 import static accord.primitives.Txn.Kind.Read;
 import static com.google.common.base.Preconditions.checkArgument;
@@ -125,6 +129,10 @@ public class TransactionStatement implements 
CQLStatement.CompositeCQLStatement,
     public static final String ILLEGAL_RANGE_QUERY_MESSAGE = "Range queries 
are not allowed for reads within a transaction; %s %s";
     public static final String UNSUPPORTED_MIGRATION = "Transaction Statement 
is unsupported when migrating away from Accord or before migration to Accord is 
complete for a range";
     public static final String NO_PARTITION_IN_CLAUSE_WITH_LIMIT = "Partition 
key is present in IN clause and there is a LIMIT... this is currently not 
supported; %s statement %s";
+    public static final String WRITE_TXN_EMPTY_WITH_IGNORED_READS = "Write txn 
produced no mutation, and its reads do not return to the caller; ignoring...";
+    public static final String WRITE_TXN_EMPTY_WITH_NO_READS = "Write txn 
produced no mutation, and had no reads; ignoring...";
+
+    private static NoSpamLogger noSpamLogger = 
NoSpamLogger.getLogger(LoggerFactory.getLogger(TransactionStatement.class), 1, 
TimeUnit.MINUTES);
 
     static class NamedSelect
     {
@@ -350,9 +358,8 @@ public class TransactionStatement implements 
CQLStatement.CompositeCQLStatement,
         int idx = 0;
         for (ModificationStatement modification : updates)
         {
-            TxnWrite.Fragment fragment = modification.getTxnWriteFragment(idx, 
state, options, keyCollector);
-            minEpoch = Math.max(minEpoch, 
fragment.baseUpdate.metadata().epoch.getEpoch());
-            fragments.add(fragment);
+            minEpoch = Math.max(minEpoch, 
modification.metadata().epoch.getEpoch());
+            fragments.addAll(modification.getTxnWriteFragment(idx, state, 
options, keyCollector));
 
             if 
(modification.allReferenceOperations().stream().anyMatch(ReferenceOperation::requiresRead))
             {
@@ -447,6 +454,7 @@ public class TransactionStatement implements 
CQLStatement.CompositeCQLStatement,
     }
 
     @VisibleForTesting
+    @Nullable
     public Txn createTxn(ClientState state, QueryOptions options)
     {
         ClusterMetadata cm = ClusterMetadata.current();
@@ -467,8 +475,15 @@ public class TransactionStatement implements 
CQLStatement.CompositeCQLStatement,
         {
             Int2ObjectHashMap<NamedSelect> autoReads = new 
Int2ObjectHashMap<>();
             List<TxnWrite.Fragment> writeFragments = 
createWriteFragments(state, options, autoReads, keyCollector);
-            ConsistencyLevel commitCL = consistencyLevelForAccordCommit(cm, 
tables, keyCollector, options.getConsistency());
             List<TxnNamedRead> reads = createNamedReads(options, autoReads, 
keyCollector);
+            if (writeFragments.isEmpty()) // ModificationStatement yield no 
Mutation (DELETE WHERE pk=0 AND c < 0 AND c > 0 -- matches no keys; so has no 
mutation)
+            {
+                // cleanup memory
+                keyCollector.clear();
+                autoReads.clear();
+                return maybeCreateTxnFromEmptyWrites(cm, options, tables);
+            }
+            ConsistencyLevel commitCL = consistencyLevelForAccordCommit(cm, 
tables, keyCollector, options.getConsistency());
             Keys keys = keyCollector.build();
             AccordUpdate update = new TxnUpdate(tables, writeFragments, 
createCondition(options), commitCL, PreserveTimestamp.no);
             TxnRead read = createTxnRead(tables, reads, null, Domain.Key);
@@ -476,6 +491,31 @@ public class TransactionStatement implements 
CQLStatement.CompositeCQLStatement,
         }
     }
 
+    @Nullable
+    private Txn.InMemory maybeCreateTxnFromEmptyWrites(ClusterMetadata cm, 
QueryOptions options, TableMetadatas.Complete tables)
+    {
+        TableMetadatasAndKeys.KeyCollector keyCollector = new 
TableMetadatasAndKeys.KeyCollector(tables);
+        List<TxnNamedRead> reads = createNamedReads(options, null, 
keyCollector);
+        if (reads.isEmpty())
+        {
+            // no reads, this is a no-op
+            noSpamLogger.info(WRITE_TXN_EMPTY_WITH_NO_READS);
+            return null;
+        }
+        if (returningSelect == null && returningReferences == null)
+        {
+            // the reads were for the mutation, and since the mutation doesn't 
exist the reads are not needed
+            noSpamLogger.info(WRITE_TXN_EMPTY_WITH_IGNORED_READS);
+            return null;
+        }
+
+        // Return a read only txn
+        Keys keys = keyCollector.build();
+        TxnRead read = createTxnRead(tables, reads, 
consistencyLevelForAccordRead(cm, tables, keys, 
options.getSerialConsistency()), Domain.Key);
+        Txn.Kind kind = shouldReadEphemerally(keys, 
tables.getMetadata((TableId)keys.get(0).prefix()).params, Read);
+        return new Txn.InMemory(kind, keys, read, TxnQuery.ALL, null, new 
TableMetadatasAndKeys(tables, keys));
+    }
+
     /**
      * Returns {@code true} only if the statement selects multiple clusterings 
in a partition
      */
@@ -514,6 +554,8 @@ public class TransactionStatement implements 
CQLStatement.CompositeCQLStatement,
         }
 
         Txn txn = createTxn(state.getClientState(), options);
+        if (txn == null)
+            return new ResultMessage.Void();
 
         TxnResult txnResult = AccordService.instance().coordinate(minEpoch, 
txn, options.getConsistency(), requestTime);
         if (txnResult.kind() == retry_new_protocol)
diff --git 
a/src/java/org/apache/cassandra/service/accord/serializers/AbstractSortedCollector.java
 
b/src/java/org/apache/cassandra/service/accord/serializers/AbstractSortedCollector.java
index d979dab3fa..80672054ed 100644
--- 
a/src/java/org/apache/cassandra/service/accord/serializers/AbstractSortedCollector.java
+++ 
b/src/java/org/apache/cassandra/service/accord/serializers/AbstractSortedCollector.java
@@ -97,6 +97,14 @@ public abstract class AbstractSortedCollector<T, C> extends 
AbstractList<T>
         return add;
     }
 
+    public void clear()
+    {
+        if (count > 1)
+            cachedAny().forceDiscard((Object[])buffer, count);
+        buffer = null;
+        count = 0;
+    }
+
     public C build()
     {
         if (count == 0)
diff --git 
a/src/java/org/apache/cassandra/service/consensus/TransactionalMode.java 
b/src/java/org/apache/cassandra/service/consensus/TransactionalMode.java
index 5355d33d10..b2336522c6 100644
--- a/src/java/org/apache/cassandra/service/consensus/TransactionalMode.java
+++ b/src/java/org/apache/cassandra/service/consensus/TransactionalMode.java
@@ -18,6 +18,8 @@
 
 package org.apache.cassandra.service.consensus;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.cassandra.db.ConsistencyLevel;
 import org.apache.cassandra.db.PartitionPosition;
 import org.apache.cassandra.dht.AbstractBounds;
@@ -289,6 +291,12 @@ public enum TransactionalMode
         return valueOf(toLowerCaseLocalized(name));
     }
 
+    @VisibleForTesting
+    public static TransactionalMode[] supported()
+    {
+        return new TransactionalMode[]{ TransactionalMode.off, 
TransactionalMode.mixed_reads, TransactionalMode.full };
+    }
+
     public boolean isTestMode()
     {
         return name().startsWith("test_");
diff --git 
a/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java
 
b/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java
index f89fb1a986..da972c9432 100644
--- 
a/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java
+++ 
b/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java
@@ -34,6 +34,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -78,6 +79,7 @@ import org.apache.cassandra.service.accord.AccordService;
 import org.apache.cassandra.service.accord.AccordTestUtils;
 import org.apache.cassandra.service.consensus.TransactionalMode;
 import 
org.apache.cassandra.service.consensus.migration.TransactionalMigrationFromMode;
+import org.apache.cassandra.utils.AssertionUtils;
 import org.apache.cassandra.utils.ByteBufferUtil;
 import org.apache.cassandra.utils.FailingConsumer;
 import org.apache.cassandra.utils.Pair;
@@ -3349,4 +3351,65 @@ public abstract class AccordCQLTestBase extends 
AccordTestBase
                            .isEmpty();
         });
     }
+
+    @Test
+    public void emptyModification() throws Exception
+    {
+        test("CREATE TABLE " + qualifiedAccordTableName + " (k int, s int 
static, c int, v int, PRIMARY KEY (k, c)) WITH " + 
transactionalMode.asCqlParam(), cluster -> {
+            String deleteStmt = "DELETE FROM " + qualifiedAccordTableName + " 
WHERE k=0 AND c < 0 AND c > 0";
+            String selectStmt = "SELECT * FROM " + qualifiedAccordTableName + 
" WHERE k=0";
+            ICoordinator node = cluster.coordinator(1);
+            node.execute("INSERT INTO " + qualifiedAccordTableName + " (k, s, 
c, v) VALUES (0, 0, 0, 0)", QUORUM);
+
+            // CAS rejects
+            Assertions.assertThatThrownBy(() -> node.execute(deleteStmt + " IF 
s=0", QUORUM))
+                      
.is(AssertionUtils.isInstanceof(InvalidRequestException.class))
+                      .hasMessageContaining("DELETE statements must restrict 
all PRIMARY KEY columns with equality relations");
+
+            // BEGIN TRANSACTION does not!  This should no-op (user has no way 
to know it did)
+            node.execute(wrapInTxn(deleteStmt), QUORUM);
+            
Assertions.assertThat(node.instance().logs().watchFor(TransactionStatement.WRITE_TXN_EMPTY_WITH_NO_READS).getResult()).isNotEmpty();
+
+            // if there was a read, the txn was downgraded to a read txn
+            var results = node.execute(wrapInTxn(selectStmt, deleteStmt), 
QUORUM);
+            Assertions.assertThat(results).hasDimensions(1, 4);
+
+            // there are lets but no returning
+            node.execute(wrapInTxn("LET a = (" + selectStmt + " LIMIT 1" + 
')', deleteStmt), QUORUM);
+            
Assertions.assertThat(node.instance().logs().watchFor(TransactionStatement.WRITE_TXN_EMPTY_WITH_IGNORED_READS).getResult()).isNotEmpty();
+        });
+    }
+
+    @Test
+    public void multiPartitionUpdate() throws Exception
+    {
+        test("CREATE TABLE " + qualifiedAccordTableName + "(k int PRIMARY KEY, 
v int) WITH " + transactionalMode.asCqlParam(), cluster -> {
+            var node = cluster.coordinator(1);
+            int numPartitions = 10;
+            for (int i = 0; i < numPartitions; i++)
+                node.execute("INSERT INTO " + qualifiedAccordTableName + "(k, 
v) VALUES (?, ?)", QUORUM, i, 0);
+
+            Object[] binds = IntStream.range(0, 
numPartitions).boxed().toArray();
+            String where = "WHERE k IN (" + IntStream.range(0, 
numPartitions).mapToObj(i -> "?").collect(Collectors.joining(", ")) + ')';
+            String updateCQL = "UPDATE " + qualifiedAccordTableName + " SET 
v=1 " + where;
+            String deleteCQL = "DELETE FROM " + qualifiedAccordTableName + ' ' 
+ where;
+
+            // update multiple partitions at once
+            node.execute(wrapInTxn(updateCQL), QUORUM, binds);
+            for (int i = 0; i < numPartitions; i++)
+            {
+                var qr = node.executeWithResult("SELECT v FROM " + 
qualifiedAccordTableName + " WHERE k=?", QUORUM, i);
+                
QueryResultUtil.assertThat(qr).isEqualTo(QueryResults.builder().row(1).build());
+            }
+
+            // now delete
+            node.execute(wrapInTxn(deleteCQL), QUORUM, binds);
+
+            for (int i = 0; i < numPartitions; i++)
+            {
+                var qr = node.executeWithResult("SELECT v FROM " + 
qualifiedAccordTableName + " WHERE k=?", QUORUM, i);
+                QueryResultUtil.assertThat(qr).isEmpty();
+            }
+        });
+    }
 }
diff --git 
a/test/distributed/org/apache/cassandra/distributed/test/cql3/AccordInteropMultiNodeTableWalkBase.java
 
b/test/distributed/org/apache/cassandra/distributed/test/cql3/AccordInteropMultiNodeTableWalkBase.java
index 7d54906b7d..d1a5e0f32e 100644
--- 
a/test/distributed/org/apache/cassandra/distributed/test/cql3/AccordInteropMultiNodeTableWalkBase.java
+++ 
b/test/distributed/org/apache/cassandra/distributed/test/cql3/AccordInteropMultiNodeTableWalkBase.java
@@ -18,9 +18,13 @@
 
 package org.apache.cassandra.distributed.test.cql3;
 
+import javax.annotation.Nullable;
+
 import accord.utils.Property;
 import accord.utils.RandomSource;
 import org.apache.cassandra.cql3.KnownIssue;
+import org.apache.cassandra.cql3.ast.Mutation;
+import org.apache.cassandra.cql3.ast.Txn;
 import org.apache.cassandra.distributed.Cluster;
 import org.apache.cassandra.distributed.api.ConsistencyLevel;
 import org.apache.cassandra.distributed.shared.ClusterUtils;
@@ -86,11 +90,14 @@ Suppressed: java.lang.AssertionError: Unknown keyspace ks12
     public class AccordInteropMultiNodeState extends MultiNodeState
     {
         private final boolean allowUsingTimestamp;
+        private final float wrapMutationAsTxn;
 
         public AccordInteropMultiNodeState(RandomSource rs, Cluster cluster)
         {
             super(rs, cluster);
             allowUsingTimestamp = rs.nextBoolean();
+            // when USING TIMESTAMP is done for the mutation, BEGIN 
TRANSACTION can't be supported as it doesn't allow that syntax; so need to 
disable wrapping mutations
+            wrapMutationAsTxn = allowUsingTimestamp ? 0F : rs.nextFloat();
         }
 
         @Override
@@ -102,6 +109,14 @@ Suppressed: java.lang.AssertionError: Unknown keyspace ks12
             ClusterUtils.awaitAccordEpochReady(cluster, maxEpoch.getEpoch());
         }
 
+        @Override
+        protected <S extends BaseState> Property.Command<S, Void, ?> 
command(RandomSource rs, Mutation mutation, @Nullable String annotate)
+        {
+            if (wrapMutationAsTxn != 0 && rs.decide(wrapMutationAsTxn))
+                return super.command(rs, Txn.wrap(mutation), annotate);
+            return super.command(rs, mutation, annotate);
+        }
+
         @Override
         protected boolean allowUsingTimestamp()
         {
diff --git 
a/test/distributed/org/apache/cassandra/distributed/test/cql3/SingleNodeTableWalkTest.java
 
b/test/distributed/org/apache/cassandra/distributed/test/cql3/SingleNodeTableWalkTest.java
index a890efed4c..a2a526121b 100644
--- 
a/test/distributed/org/apache/cassandra/distributed/test/cql3/SingleNodeTableWalkTest.java
+++ 
b/test/distributed/org/apache/cassandra/distributed/test/cql3/SingleNodeTableWalkTest.java
@@ -444,28 +444,33 @@ public class SingleNodeTableWalkTest extends 
StatefulASTBase
 
             cluster.forEach(i -> i.nodetoolResult("disableautocompaction", 
metadata.keyspace, this.metadata.name).asserts().success());
 
-            List<LinkedHashMap<Symbol, Object>> uniquePartitions;
+            ASTGenerators.MutationGenBuilder mutationGenBuilder = new 
ASTGenerators.MutationGenBuilder(metadata)
+                                                                  
.withTxnSafe()
+                                                                  
.withColumnExpressions(e -> 
e.withOperators(Generators.fromGen(BOOLEAN_DISTRIBUTION.next(rs))))
+                                                                  
.withIgnoreIssues(IGNORED_ISSUES);
+
+            // Run the test with and without bound partitions
+            // When using fixed partitions, each mutation will be for a single 
partition and will use pk=? syntax
+            // When using unbounded partitions then IN clause is used on 
partition keys, leading to mutations touching multiple partitions
+            if (rs.nextBoolean())
             {
-                int unique = rs.nextInt(1, 10);
-                List<Symbol> columns = model.factory.partitionColumns;
-                List<Gen<?>> gens = new ArrayList<>(columns.size());
-                for (int i = 0; i < columns.size(); i++)
-                    
gens.add(toGen(getTypeSupport(columns.get(i).type()).valueGen));
-                uniquePartitions = Gens.lists(r2 -> {
-                    LinkedHashMap<Symbol, Object> vs = new LinkedHashMap<>();
+                List<LinkedHashMap<Symbol, Object>> uniquePartitions;
+                {
+                    int unique = rs.nextInt(1, 10);
+                    List<Symbol> columns = model.factory.partitionColumns;
+                    List<Gen<?>> gens = new ArrayList<>(columns.size());
                     for (int i = 0; i < columns.size(); i++)
-                        vs.put(columns.get(i), gens.get(i).next(r2));
-                    return vs;
-                }).uniqueBestEffort().ofSize(unique).next(rs);
+                        
gens.add(toGen(getTypeSupport(columns.get(i).type()).valueGen));
+                    uniquePartitions = Gens.lists(r2 -> {
+                        LinkedHashMap<Symbol, Object> vs = new 
LinkedHashMap<>();
+                        for (int i = 0; i < columns.size(); i++)
+                            vs.put(columns.get(i), gens.get(i).next(r2));
+                        return vs;
+                    }).uniqueBestEffort().ofSize(unique).next(rs);
+                }
+                
mutationGenBuilder.withPartitions(Generators.fromGen(Gens.mixedDistribution(uniquePartitions).next(rs)));
             }
 
-            ASTGenerators.MutationGenBuilder mutationGenBuilder = new 
ASTGenerators.MutationGenBuilder(metadata)
-                                                                  
.withoutTransaction()
-                                                                  .withoutTtl()
-                                                                  
.withoutTimestamp()
-                                                                  
.withPartitions(Generators.fromGen(Gens.mixedDistribution(uniquePartitions).next(rs)))
-                                                                  
.withColumnExpressions(e -> 
e.withOperators(Generators.fromGen(BOOLEAN_DISTRIBUTION.next(rs))))
-                                                                  
.withIgnoreIssues(IGNORED_ISSUES);
             if (IGNORED_ISSUES.contains(KnownIssue.SAI_EMPTY_TYPE))
             {
                 model.factory.regularAndStaticColumns.stream()
diff --git 
a/test/distributed/org/apache/cassandra/distributed/test/cql3/StatefulASTBase.java
 
b/test/distributed/org/apache/cassandra/distributed/test/cql3/StatefulASTBase.java
index 19f95f4e5a..756095673e 100644
--- 
a/test/distributed/org/apache/cassandra/distributed/test/cql3/StatefulASTBase.java
+++ 
b/test/distributed/org/apache/cassandra/distributed/test/cql3/StatefulASTBase.java
@@ -19,9 +19,7 @@
 package org.apache.cassandra.distributed.test.cql3;
 
 import java.io.IOException;
-import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.EnumSet;
 import java.util.List;
@@ -33,23 +31,15 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import javax.annotation.Nullable;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Maps;
 import org.slf4j.Logger;
 
 import accord.utils.Gen;
 import accord.utils.Gens;
 import accord.utils.Property;
 import accord.utils.RandomSource;
-import com.datastax.driver.core.ColumnDefinitions;
-import com.datastax.driver.core.ResultSet;
-import com.datastax.driver.core.Row;
 import com.datastax.driver.core.Session;
-import com.datastax.driver.core.SimpleStatement;
 import com.datastax.driver.core.SocketOptions;
-import com.datastax.driver.core.exceptions.ReadFailureException;
-import com.datastax.driver.core.exceptions.WriteFailureException;
 import org.apache.cassandra.config.CassandraRelevantProperties;
 import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.cql3.KnownIssue;
@@ -63,6 +53,7 @@ import org.apache.cassandra.cql3.ast.Select;
 import org.apache.cassandra.cql3.ast.StandardVisitors;
 import org.apache.cassandra.cql3.ast.Statement;
 import org.apache.cassandra.cql3.ast.TableReference;
+import org.apache.cassandra.cql3.ast.Txn;
 import org.apache.cassandra.cql3.ast.Value;
 import org.apache.cassandra.cql3.ast.Visitor;
 import org.apache.cassandra.cql3.ast.Visitor.CompositeVisitor;
@@ -81,7 +72,7 @@ import org.apache.cassandra.distributed.api.IInstanceConfig;
 import org.apache.cassandra.distributed.api.IInvokableInstance;
 import org.apache.cassandra.distributed.test.JavaDriverUtils;
 import org.apache.cassandra.distributed.test.TestBaseImpl;
-import org.apache.cassandra.exceptions.RequestFailureReason;
+import org.apache.cassandra.distributed.util.DriverUtils;
 import org.apache.cassandra.harry.model.ASTSingleTableModel;
 import org.apache.cassandra.harry.util.StringUtils;
 import org.apache.cassandra.repair.RepairGenerators;
@@ -97,7 +88,6 @@ import org.quicktheories.generators.SourceDSL;
 
 import static accord.utils.Property.ignoreCommand;
 import static accord.utils.Property.multistep;
-import static org.apache.cassandra.distributed.test.JavaDriverUtils.toDriverCL;
 import static 
org.apache.cassandra.utils.AbstractTypeGenerators.overridePrimitiveTypeSupport;
 import static 
org.apache.cassandra.utils.AbstractTypeGenerators.stringComparator;
 
@@ -612,6 +602,29 @@ public class StatefulASTBase extends TestBaseImpl
             });
         }
 
+        protected <S extends BaseState> Property.Command<S, Void, ?> 
command(RandomSource rs, Txn txn)
+        {
+            return command(rs, txn, null);
+        }
+
+        protected <S extends BaseState> Property.Command<S, Void, ?> 
command(RandomSource rs, Txn txn, @Nullable String annotate)
+        {
+            var inst = selectInstance(rs);
+            String postfix = "on " + inst;
+            if (model.isConditional(txn))
+                postfix += ", would apply " + model.shouldApply(txn);
+            if (annotate == null) annotate = postfix;
+            else annotate += ", " + postfix;
+
+            return new Property.SimpleCommand<>(humanReadable(txn, annotate), 
s -> {
+                boolean hasMutation = txn.ifBlock.isPresent() || 
!txn.mutations.isEmpty();
+                ConsistencyLevel cl = hasMutation ? s.mutationCl() : 
s.selectCl();
+                s.model.updateAndValidate(s.executeQuery(inst, 
Integer.MAX_VALUE, cl, txn), txn);
+                if (hasMutation)
+                    s.mutation();
+            });
+        }
+
         protected IInvokableInstance selectInstance(RandomSource rs)
         {
             return cluster.get(rs.nextInt(0, cluster.size()) + 1);
@@ -686,82 +699,7 @@ public class StatefulASTBase extends TestBaseImpl
                 instance.executeInternal(stmt.toCQL(), (Object[]) 
stmt.bindsEncoded());
                 return new ByteBuffer[0][];
             }
-            else
-            {
-                SimpleStatement ss = new SimpleStatement(stmt.toCQL(), 
(Object[]) stmt.bindsEncoded());
-                if (fetchSize != Integer.MAX_VALUE)
-                    ss.setFetchSize(fetchSize);
-                if (stmt.kind() == Statement.Kind.MUTATION)
-                {
-                    switch (cl)
-                    {
-                        case SERIAL:
-                            ss.setSerialConsistencyLevel(toDriverCL(cl));
-                            
ss.setConsistencyLevel(com.datastax.driver.core.ConsistencyLevel.QUORUM);
-                            break;
-                        case LOCAL_SERIAL:
-                            ss.setSerialConsistencyLevel(toDriverCL(cl));
-                            
ss.setConsistencyLevel(com.datastax.driver.core.ConsistencyLevel.LOCAL_QUORUM);
-                            break;
-                        default:
-                            ss.setConsistencyLevel(toDriverCL(cl));
-                    }
-                }
-                else
-                {
-                    ss.setConsistencyLevel(toDriverCL(cl));
-                }
-
-                InetSocketAddress broadcastAddress = 
instance.config().broadcastAddress();
-                var host = client.getMetadata().getAllHosts().stream()
-                                 .filter(h -> 
h.getBroadcastSocketAddress().getAddress().equals(broadcastAddress.getAddress()))
-                                 .filter(h -> 
h.getBroadcastSocketAddress().getPort() == broadcastAddress.getPort())
-                                 .findAny()
-                                 .get();
-                ss.setHost(host);
-                ResultSet result;
-                try
-                {
-                    result = session.execute(ss);
-                }
-                catch (ReadFailureException t)
-                {
-                    throw new AssertionError("failed from=" + 
Maps.transformValues(t.getFailuresMap(), BaseState::safeErrorCode), t);
-                }
-                catch (WriteFailureException t)
-                {
-                    throw new AssertionError("failed from=" + 
Maps.transformValues(t.getFailuresMap(), BaseState::safeErrorCode), t);
-                }
-                return getRowsAsByteBuffer(result);
-            }
-        }
-
-        private static String safeErrorCode(Integer code)
-        {
-            try
-            {
-                return RequestFailureReason.fromCode(code).name();
-            }
-            catch (IllegalArgumentException e)
-            {
-                return "Unexpected code " + code + ": " + e.getMessage();
-            }
-        }
-
-        @VisibleForTesting
-        static ByteBuffer[][] getRowsAsByteBuffer(ResultSet result)
-        {
-            ColumnDefinitions columns = result.getColumnDefinitions();
-            List<ByteBuffer[]> ret = new ArrayList<>();
-            for (Row rowVal : result)
-            {
-                ByteBuffer[] row = new ByteBuffer[columns.size()];
-                for (int i = 0; i < columns.size(); i++)
-                    row[i] = rowVal.getBytesUnsafe(i);
-                ret.add(row);
-            }
-            ByteBuffer[][] a = new ByteBuffer[ret.size()][];
-            return ret.toArray(a);
+            return DriverUtils.executeQuery(session, instance, fetchSize, cl, 
stmt);
         }
 
         protected String humanReadable(Statement stmt, @Nullable String 
annotate)
diff --git 
a/test/distributed/org/apache/cassandra/distributed/util/DriverUtils.java 
b/test/distributed/org/apache/cassandra/distributed/util/DriverUtils.java
new file mode 100644
index 0000000000..46ed21ee9a
--- /dev/null
+++ b/test/distributed/org/apache/cassandra/distributed/util/DriverUtils.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.util;
+
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import com.google.common.collect.Maps;
+
+import com.datastax.driver.core.ColumnDefinitions;
+import com.datastax.driver.core.Host;
+import com.datastax.driver.core.ResultSet;
+import com.datastax.driver.core.Row;
+import com.datastax.driver.core.Session;
+import com.datastax.driver.core.SimpleStatement;
+import com.datastax.driver.core.exceptions.ReadFailureException;
+import com.datastax.driver.core.exceptions.WriteFailureException;
+import org.apache.cassandra.cql3.ast.Statement;
+import org.apache.cassandra.distributed.api.ConsistencyLevel;
+import org.apache.cassandra.distributed.api.IInstance;
+import org.apache.cassandra.exceptions.RequestFailureReason;
+
+import static org.apache.cassandra.distributed.test.JavaDriverUtils.toDriverCL;
+
+/**
+ * Helpers for running {@link Statement}s through the Java driver against a specific in-JVM dtest instance.
+ * <p>
+ * Not instantiable; all members are static.
+ */
+public final class DriverUtils
+{
+    private DriverUtils()
+    {
+    }
+
+    /**
+     * Copies every row of {@code result} into a {@code ByteBuffer[][]}: one inner array per row, holding the
+     * raw serialized bytes of each column (via {@link Row#getBytesUnsafe(int)}).
+     *
+     * @param result the driver result to drain
+     * @return the rows, in the order the driver's iterator yields them
+     */
+    public static ByteBuffer[][] getRowsAsByteBuffer(ResultSet result)
+    {
+        ColumnDefinitions columns = result.getColumnDefinitions();
+        int width = columns.size();
+        List<ByteBuffer[]> rows = new ArrayList<>();
+        for (Row rowVal : result)
+        {
+            ByteBuffer[] row = new ByteBuffer[width];
+            for (int i = 0; i < width; i++)
+                row[i] = rowVal.getBytesUnsafe(i);
+            rows.add(row);
+        }
+        return rows.toArray(new ByteBuffer[0][]);
+    }
+
+    /**
+     * Executes {@code stmt} through {@code session}, pinned to the driver host whose broadcast address matches
+     * {@code instance}, and returns the result rows as raw bytes.
+     * <p>
+     * For mutations, {@code SERIAL}/{@code LOCAL_SERIAL} is sent as the serial consistency level, with
+     * {@code QUORUM}/{@code LOCAL_QUORUM} respectively as the regular consistency level; any other level, and
+     * any non-mutation statement, uses {@code cl} directly as the regular consistency level.
+     *
+     * @param fetchSize page size to request, or {@link Integer#MAX_VALUE} to leave the driver's default untouched
+     * @return the rows as produced by {@link #getRowsAsByteBuffer(ResultSet)}
+     * @throws AssertionError if no driver host matches the instance, or if the request fails with a
+     *                        read/write failure (the message maps each failing endpoint to its
+     *                        {@link RequestFailureReason} name)
+     */
+    public static ByteBuffer[][] executeQuery(Session session,
+                                              IInstance instance,
+                                              int fetchSize,
+                                              ConsistencyLevel cl,
+                                              Statement stmt)
+    {
+        SimpleStatement ss = new SimpleStatement(stmt.toCQL(), (Object[]) stmt.bindsEncoded());
+        if (fetchSize != Integer.MAX_VALUE)
+            ss.setFetchSize(fetchSize);
+        if (stmt.kind() == Statement.Kind.MUTATION)
+        {
+            // Serial levels are not valid as the regular CL of a write; send them as the serial CL
+            // and pair them with the matching quorum level.
+            switch (cl)
+            {
+                case SERIAL:
+                    ss.setSerialConsistencyLevel(toDriverCL(cl));
+                    ss.setConsistencyLevel(com.datastax.driver.core.ConsistencyLevel.QUORUM);
+                    break;
+                case LOCAL_SERIAL:
+                    ss.setSerialConsistencyLevel(toDriverCL(cl));
+                    ss.setConsistencyLevel(com.datastax.driver.core.ConsistencyLevel.LOCAL_QUORUM);
+                    break;
+                default:
+                    ss.setConsistencyLevel(toDriverCL(cl));
+            }
+        }
+        else
+        {
+            ss.setConsistencyLevel(toDriverCL(cl));
+        }
+
+        ss.setHost(getHost(session, instance));
+        ResultSet result;
+        try
+        {
+            result = session.execute(ss);
+        }
+        // The two exception types do not share a supertype declaring getFailuresMap(), so they cannot be multi-caught.
+        catch (ReadFailureException t)
+        {
+            throw new AssertionError("failed from=" + Maps.transformValues(t.getFailuresMap(), DriverUtils::safeErrorCode), t);
+        }
+        catch (WriteFailureException t)
+        {
+            throw new AssertionError("failed from=" + Maps.transformValues(t.getFailuresMap(), DriverUtils::safeErrorCode), t);
+        }
+        return getRowsAsByteBuffer(result);
+    }
+
+    /**
+     * Finds the driver {@link Host} whose broadcast address and port equal {@code instance}'s configured
+     * broadcast address.
+     *
+     * @throws AssertionError if no known host matches, listing the hosts the driver does know about
+     */
+    private static Host getHost(Session session, IInstance instance)
+    {
+        InetSocketAddress broadcastAddress = instance.config().broadcastAddress();
+        var hosts = session.getCluster().getMetadata().getAllHosts();
+        return hosts.stream()
+                    // a host whose broadcast address is not (yet) known cannot be matched
+                    .filter(h -> h.getBroadcastSocketAddress() != null)
+                    .filter(h -> h.getBroadcastSocketAddress().getAddress().equals(broadcastAddress.getAddress()))
+                    .filter(h -> h.getBroadcastSocketAddress().getPort() == broadcastAddress.getPort())
+                    .findAny()
+                    .orElseThrow(() -> new AssertionError("No driver host found with broadcast address " + broadcastAddress
+                                                          + "; known hosts: " + hosts));
+    }
+
+    /**
+     * Renders a failure code as its {@link RequestFailureReason} name, or as a descriptive string when
+     * {@link RequestFailureReason#fromCode} rejects the code, so that building an error message never throws.
+     */
+    private static String safeErrorCode(Integer code)
+    {
+        try
+        {
+            return RequestFailureReason.fromCode(code).name();
+        }
+        catch (IllegalArgumentException e)
+        {
+            return "Unexpected code " + code + ": " + e.getMessage();
+        }
+    }
+}
diff --git 
a/test/distributed/org/apache/cassandra/fuzz/topology/AccordTopologyMixupTest.java
 
b/test/distributed/org/apache/cassandra/fuzz/topology/AccordTopologyMixupTest.java
index 8895a7271b..d75307514d 100644
--- 
a/test/distributed/org/apache/cassandra/fuzz/topology/AccordTopologyMixupTest.java
+++ 
b/test/distributed/org/apache/cassandra/fuzz/topology/AccordTopologyMixupTest.java
@@ -105,7 +105,7 @@ public class AccordTopologyMixupTest extends 
TopologyMixupTestBase<AccordTopolog
         overridePrimitiveTypeSupport(BytesType.instance, 
AbstractTypeGenerators.TypeSupport.of(BytesType.instance, Generators.bytes(1, 
10), FastByteOperations::compareUnsigned));
     }
 
-    private static final List<TransactionalMode> TRANSACTIONAL_MODES = 
Stream.of(TransactionalMode.values()).filter(t -> 
t.accordIsEnabled).collect(Collectors.toList());
+    private static final List<TransactionalMode> TRANSACTIONAL_MODES = 
Stream.of(TransactionalMode.supported()).filter(t -> 
t.accordIsEnabled).collect(Collectors.toList());
 
     @Override
     protected Gen<State<Spec>> stateGen()
diff --git 
a/test/harry/main/org/apache/cassandra/harry/model/ASTSingleTableModel.java 
b/test/harry/main/org/apache/cassandra/harry/model/ASTSingleTableModel.java
index c914350328..37f4585332 100644
--- a/test/harry/main/org/apache/cassandra/harry/model/ASTSingleTableModel.java
+++ b/test/harry/main/org/apache/cassandra/harry/model/ASTSingleTableModel.java
@@ -596,7 +596,7 @@ public class ASTSingleTableModel
             }
             // table has clustering but non are in the write, so only 
pk/static can be updated
             if (!factory.clusteringColumns.isEmpty() && remaining.isEmpty())
-                return;
+                continue;
             BytesPartitionState finalPartition = partition;
             for (Clustering<ByteBuffer> cd : clustering(remaining))
             {
@@ -621,7 +621,7 @@ public class ASTSingleTableModel
         for (Clustering<ByteBuffer> pd : pks)
         {
             BytesPartitionState partition = 
partitions.get(factory.createRef(pd));
-            if (partition == null) return; // can't delete a partition that 
doesn't exist...
+            if (partition == null) continue; // can't delete a partition that 
doesn't exist...
 
             DeleteKind kind = DeleteKind.PARTITION;
             if (!delete.columns.isEmpty())
diff --git 
a/test/harry/main/org/apache/cassandra/harry/model/ASTSingleTableModelTest.java 
b/test/harry/main/org/apache/cassandra/harry/model/ASTSingleTableModelTest.java
index 3bdf742ce8..fbd9e07832 100644
--- 
a/test/harry/main/org/apache/cassandra/harry/model/ASTSingleTableModelTest.java
+++ 
b/test/harry/main/org/apache/cassandra/harry/model/ASTSingleTableModelTest.java
@@ -78,6 +78,36 @@ public class ASTSingleTableModelTest
     public static final SetType<Integer> SET_INT = 
SetType.getInstance(Int32Type.instance, true);
     public static final MapType<Integer, Integer> MAP_INT = 
MapType.getInstance(Int32Type.instance, Int32Type.instance, true);
 
+    /**
+     * An UPDATE that restricts the partition key with {@code IN} touches several partitions, and the model must
+     * apply the write to every one of them rather than stopping after the first.
+     * <p>
+     * Reduced from a fuzz failure of the shape
+     * {@code UPDATE ks1.tbl USING TIMESTAMP 44 SET s0=... WHERE pk0 IN (70, 47, -35)}, simplified here to an
+     * {@code int} static value.
+     */
+    @Test
+    public void multiplePartitionUpdate()
+    {
+        TableMetadata metadata = new Builder()
+                                 .pk(1)
+                                 .ck(1)
+                                 .statics(1)
+                                 .regular(1)
+                                 .build();
+        ASTSingleTableModel model = new ASTSingleTableModel(metadata);
+
+        // Static-only write (no clustering key) to three partitions at once.
+        model.update(Mutation.update(metadata)
+                             .timestamp(44)
+                             .set("s", 42)
+                             .in("pk", Int32Type.instance, 70, 47, -35)
+                             .build());
+
+        // Every partition must carry the static value; the other two columns stay null.
+        // NOTE: rows are listed in token order rather than IN-list order (assumes the test partitioner's ordering).
+        ByteBuffer s = value(42);
+        ByteBuffer[][] expected = rows(
+            row(value(47), null, s, null),
+            row(value(-35), null, s, null),
+            row(value(70), null, s, null)
+        );
+        model.validate(expected, Select.builder(metadata).build());
+    }
+
     @Test
     public void singlePartition()
     {
@@ -863,6 +893,11 @@ public class ASTSingleTableModelTest
         return tables;
     }
 
+    /** Serializes {@code num} as an {@code int} column value using {@link Int32Type}. */
+    private static ByteBuffer value(int num)
+    {
+        Int32Type type = Int32Type.instance;
+        return type.decompose(num);
+    }
+
     private static class ModelModel
     {
         private final ASTSingleTableModel model;
diff --git a/test/unit/org/apache/cassandra/utils/ASTGenerators.java 
b/test/unit/org/apache/cassandra/utils/ASTGenerators.java
index 860995b28f..9cd39abba7 100644
--- a/test/unit/org/apache/cassandra/utils/ASTGenerators.java
+++ b/test/unit/org/apache/cassandra/utils/ASTGenerators.java
@@ -824,7 +824,10 @@ public class ASTGenerators
                         if (deleteKind == DeleteKind.Row && 
clusteringColumns.isEmpty())
                             deleteKind = DeleteKind.Partition;
 
-                        values(rnd, columnExpressions, builder, 
partitionColumns, partitionValueGen);
+                        if (allowUpdateMultiplePartitionKeys)
+                            where(rnd, columnExpressions, builder, 
partitionColumns, partitionValueGen);
+                        else
+                            values(rnd, columnExpressions, builder, 
partitionColumns, partitionValueGen);
 
                         switch (deleteKind)
                         {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org


Reply via email to