hachikuji commented on a change in pull request #9553:
URL: https://github.com/apache/kafka/pull/9553#discussion_r547391686



##########
File path: raft/src/main/java/org/apache/kafka/raft/LeaderState.java
##########
@@ -287,4 +287,7 @@ public String name() {
         return "Leader";
     }
 
+    @Override
+    public void close() {}

Review comment:
       Maybe we could add a default no-op implementation to EpochState?

##########
File path: 
clients/src/main/java/org/apache/kafka/common/requests/FetchSnapshotResponse.java
##########
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.requests;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.UnaryOperator;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.message.FetchSnapshotResponseData;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.protocol.Message;
+
+final public class FetchSnapshotResponse extends AbstractResponse {
+    public final FetchSnapshotResponseData data;
+
+    public FetchSnapshotResponse(FetchSnapshotResponseData data) {
+        super(ApiKeys.FETCH_SNAPSHOT);
+
+        this.data = data;
+    }
+
+    @Override
+    public Map<Errors, Integer> errorCounts() {
+        Map<Errors, Integer> errors = new HashMap<>();
+
+        Errors topLevelError = Errors.forCode(data.errorCode());
+        if (topLevelError != Errors.NONE) {
+            errors.put(topLevelError, 1);
+        }
+
+        for (FetchSnapshotResponseData.TopicSnapshot topicResponse : 
data.topics()) {
+            for (FetchSnapshotResponseData.PartitionSnapshot partitionResponse 
: topicResponse.partitions()) {
+                errors.compute(Errors.forCode(partitionResponse.errorCode()),
+                    (error, count) -> count == null ? 1 : count + 1);
+            }
+        }
+
+        return errors;
+    }
+
+    @Override
+    public int throttleTimeMs() {
+        return data.throttleTimeMs();
+    }
+
+    @Override
+    protected Message data() {
+        return data;
+    }
+
+    /**
+     * Creates a FetchSnapshotResponseData with a top level error.
+     *
+     * @param error the top level error
+     * @return the created fetch snapshot response data
+     */
+    public static FetchSnapshotResponseData withTopError(Errors error) {

Review comment:
       nit: maybe `withTopLevelError`?

##########
File path: clients/src/main/resources/common/message/FetchSnapshotResponse.json
##########
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+{
+  "apiKey": 59,
+  "type": "response",
+  "name": "FetchSnapshotResponse",
+  "validVersions": "0",
+  "flexibleVersions": "0+",
+  "fields": [
+    { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", 
"ignorable": true,
+      "about": "The duration in milliseconds for which the request was 
throttled due to a quota violation, or zero if the request did not violate any 
quota." },
+    { "name": "ErrorCode", "type": "int16", "versions": "0+", "ignorable": 
false,
+      "about": "The top level response error code." },
+    { "name": "Topics", "type": "[]TopicSnapshot", "versions": "0+",
+      "about": "The topics to fetch.", "fields": [
+      { "name": "Name", "type": "string", "versions": "0+", "entityType": 
"topicName",
+        "about": "The name of the topic to fetch." },
+      { "name": "Partitions", "type": "[]PartitionSnapshot", "versions": "0+",
+        "about": "The partitions to fetch.", "fields": [
+        { "name": "Index", "type": "int32", "versions": "0+",
+          "about": "The partition index." },
+        { "name": "ErrorCode", "type": "int16", "versions": "0+",
+          "about": "The error code, or 0 if there was no fetch error." },
+        { "name": "SnapshotId", "type": "SnapshotId", "versions": "0+",
+          "about": "The snapshot endOffset and epoch fetched",
+          "fields": [
+          { "name": "EndOffset", "type": "int64", "versions": "0+" },
+          { "name": "Epoch", "type": "int32", "versions": "0+" }
+        ]},
+        { "name": "CurrentLeader", "type": "LeaderIdAndEpoch",
+          "versions": "0+", "taggedVersions": "0+", "tag": 0, "fields": [
+          { "name": "LeaderId", "type": "int32", "versions": "0+",
+            "about": "The ID of the current leader or -1 if the leader is 
unknown."},
+          { "name": "LeaderEpoch", "type": "int32", "versions": "0+",
+            "about": "The latest known leader epoch"}
+        ]},
+        { "name": "Size", "type": "int64", "versions": "0+",
+          "about": "The total size of the snapshot." },
+        { "name": "Position", "type": "int64", "versions": "0+",
+          "about": "The starting byte position within the snapshot included in 
the Bytes field." },
+        { "name": "Bytes", "type": "bytes", "versions": "0+", "zeroCopy": true,

Review comment:
       Can you remind me if we are planning to change the type to "records"? I 
don't think we will get the benefit of `sendfile` unless we do so.

##########
File path: raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
##########
@@ -1101,6 +1140,174 @@ private DescribeQuorumResponseData 
handleDescribeQuorumRequest(
         );
     }
 
+    private FetchSnapshotResponseData handleFetchSnapshotRequest(
+        RaftRequest.Inbound requestMetadata
+    ) throws IOException {
+        FetchSnapshotRequestData data = (FetchSnapshotRequestData) 
requestMetadata.data;
+
+        if (data.topics().size() != 1 && 
data.topics().get(0).partitions().size() != 1) {
+            return FetchSnapshotResponse.withTopError(Errors.INVALID_REQUEST);
+        }
+
+        Optional<FetchSnapshotRequestData.PartitionSnapshot> 
partitionSnapshotOpt = FetchSnapshotRequest
+            .forTopicPartition(data, log.topicPartition());
+        if (!partitionSnapshotOpt.isPresent()) {
+            // The Raft client assumes that there is only one topic partition.
+            TopicPartition unknownTopicPartition = new TopicPartition(
+                data.topics().get(0).name(),
+                data.topics().get(0).partitions().get(0).partition()
+            );
+
+            return FetchSnapshotResponse.singleton(
+                unknownTopicPartition,
+                responsePartitionSnapshot -> {
+                    return responsePartitionSnapshot
+                        
.setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code());
+                }
+            );
+        }
+
+        FetchSnapshotRequestData.PartitionSnapshot partitionSnapshot = 
partitionSnapshotOpt.get();
+        Optional<Errors> leaderValidation = validateLeaderOnlyRequest(
+                partitionSnapshot.currentLeaderEpoch()
+        );
+        if (leaderValidation.isPresent()) {
+            return FetchSnapshotResponse.singleton(
+                log.topicPartition(),
+                responsePartitionSnapshot -> {
+                    return addQuorumLeader(responsePartitionSnapshot)
+                        .setErrorCode(leaderValidation.get().code());
+                }
+            );
+        }
+
+        OffsetAndEpoch snapshotId = new OffsetAndEpoch(
+            partitionSnapshot.snapshotId().endOffset(),
+            partitionSnapshot.snapshotId().epoch()
+        );
+        Optional<RawSnapshotReader> snapshotOpt = log.readSnapshot(snapshotId);
+        if (!snapshotOpt.isPresent()) {
+            return FetchSnapshotResponse.singleton(
+                log.topicPartition(),
+                responsePartitionSnapshot -> {
+                    return addQuorumLeader(responsePartitionSnapshot)
+                        .setErrorCode(Errors.SNAPSHOT_NOT_FOUND.code());
+                }
+            );
+        }
+
+        try (RawSnapshotReader snapshot = snapshotOpt.get()) {
+            int maxSnapshotSize;
+            try {
+                maxSnapshotSize = Math.toIntExact(snapshot.sizeInBytes());
+            } catch (ArithmeticException e) {
+                maxSnapshotSize = Integer.MAX_VALUE;
+            }
+
+            // TODO: Make sure that we also limit based on the fetch max bytes 
configuration
+            ByteBuffer buffer = ByteBuffer.allocate(Math.min(data.maxBytes(), 
maxSnapshotSize));
+            snapshot.read(buffer, partitionSnapshot.position());
+            buffer.flip();
+
+            long snapshotSize = snapshot.sizeInBytes();
+
+            return FetchSnapshotResponse.singleton(
+                log.topicPartition(),
+                responsePartitionSnapshot -> {
+                    addQuorumLeader(responsePartitionSnapshot)
+                        .snapshotId()
+                        .setEndOffset(snapshotId.offset)
+                        .setEpoch(snapshotId.epoch);
+
+                    return responsePartitionSnapshot
+                        .setSize(snapshotSize)
+                        .setPosition(partitionSnapshot.position())
+                        .setBytes(buffer);
+                }
+            );
+        }
+    }
+
+    private boolean handleFetchSnapshotResponse(
+        RaftResponse.Inbound responseMetadata,
+        long currentTimeMs
+    ) throws IOException {
+        FetchSnapshotResponseData data = (FetchSnapshotResponseData) 
responseMetadata.data;
+        Errors topLevelError = Errors.forCode(data.errorCode());
+        if (topLevelError != Errors.NONE) {
+            // TODO: check what values this expression returns

Review comment:
       Address TODO?

##########
File path: raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
##########
@@ -941,6 +949,8 @@ private FetchResponseData tryCompleteFetchRequest(
     ) {
         Optional<Errors> errorOpt = 
validateLeaderOnlyRequest(request.currentLeaderEpoch());
         if (errorOpt.isPresent()) {
+            // TODO: The replica should return what information it knows about 
the current epoch and

Review comment:
       Isn't this already done by `buildFetchResponse`?

##########
File path: raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
##########
@@ -1101,6 +1140,174 @@ private DescribeQuorumResponseData 
handleDescribeQuorumRequest(
         );
     }
 
+    private FetchSnapshotResponseData handleFetchSnapshotRequest(
+        RaftRequest.Inbound requestMetadata
+    ) throws IOException {
+        FetchSnapshotRequestData data = (FetchSnapshotRequestData) 
requestMetadata.data;
+
+        if (data.topics().size() != 1 && 
data.topics().get(0).partitions().size() != 1) {
+            return FetchSnapshotResponse.withTopError(Errors.INVALID_REQUEST);
+        }
+
+        Optional<FetchSnapshotRequestData.PartitionSnapshot> 
partitionSnapshotOpt = FetchSnapshotRequest
+            .forTopicPartition(data, log.topicPartition());
+        if (!partitionSnapshotOpt.isPresent()) {
+            // The Raft client assumes that there is only one topic partition.
+            TopicPartition unknownTopicPartition = new TopicPartition(
+                data.topics().get(0).name(),
+                data.topics().get(0).partitions().get(0).partition()
+            );
+
+            return FetchSnapshotResponse.singleton(
+                unknownTopicPartition,
+                responsePartitionSnapshot -> {
+                    return responsePartitionSnapshot
+                        
.setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code());
+                }
+            );
+        }
+
+        FetchSnapshotRequestData.PartitionSnapshot partitionSnapshot = 
partitionSnapshotOpt.get();
+        Optional<Errors> leaderValidation = validateLeaderOnlyRequest(
+                partitionSnapshot.currentLeaderEpoch()
+        );
+        if (leaderValidation.isPresent()) {

Review comment:
       nit: I guess you could use `leaderValidation.ifPresent`

##########
File path: raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
##########
@@ -1037,6 +1047,35 @@ private boolean handleFetchResponse(
                     logger.info("Truncated to offset {} from Fetch response 
from leader {}",
                         truncationOffset, quorum.leaderIdOrNil());
                 });
+            } else if (partitionResponse.snapshotId().epoch() >= 0 ||
+                       partitionResponse.snapshotId().endOffset() >= 0) {
+                // The leader is asking us to fetch a snapshot
+
+                if (partitionResponse.snapshotId().epoch() < 0) {
+                    throw new KafkaException(

Review comment:
       Hmm... The leader has sent a bad response. I think logging an error and 
retrying might be better than crashing the follower.

##########
File path: raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
##########
@@ -1101,6 +1140,174 @@ private DescribeQuorumResponseData 
handleDescribeQuorumRequest(
         );
     }
 
+    private FetchSnapshotResponseData handleFetchSnapshotRequest(
+        RaftRequest.Inbound requestMetadata
+    ) throws IOException {
+        FetchSnapshotRequestData data = (FetchSnapshotRequestData) 
requestMetadata.data;
+
+        if (data.topics().size() != 1 && 
data.topics().get(0).partitions().size() != 1) {
+            return FetchSnapshotResponse.withTopError(Errors.INVALID_REQUEST);
+        }
+
+        Optional<FetchSnapshotRequestData.PartitionSnapshot> 
partitionSnapshotOpt = FetchSnapshotRequest
+            .forTopicPartition(data, log.topicPartition());
+        if (!partitionSnapshotOpt.isPresent()) {
+            // The Raft client assumes that there is only one topic partition.
+            TopicPartition unknownTopicPartition = new TopicPartition(
+                data.topics().get(0).name(),
+                data.topics().get(0).partitions().get(0).partition()
+            );
+
+            return FetchSnapshotResponse.singleton(
+                unknownTopicPartition,
+                responsePartitionSnapshot -> {
+                    return responsePartitionSnapshot
+                        
.setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code());
+                }
+            );
+        }
+
+        FetchSnapshotRequestData.PartitionSnapshot partitionSnapshot = 
partitionSnapshotOpt.get();
+        Optional<Errors> leaderValidation = validateLeaderOnlyRequest(
+                partitionSnapshot.currentLeaderEpoch()
+        );
+        if (leaderValidation.isPresent()) {
+            return FetchSnapshotResponse.singleton(
+                log.topicPartition(),
+                responsePartitionSnapshot -> {
+                    return addQuorumLeader(responsePartitionSnapshot)
+                        .setErrorCode(leaderValidation.get().code());
+                }
+            );
+        }
+
+        OffsetAndEpoch snapshotId = new OffsetAndEpoch(
+            partitionSnapshot.snapshotId().endOffset(),
+            partitionSnapshot.snapshotId().epoch()
+        );
+        Optional<RawSnapshotReader> snapshotOpt = log.readSnapshot(snapshotId);
+        if (!snapshotOpt.isPresent()) {
+            return FetchSnapshotResponse.singleton(
+                log.topicPartition(),
+                responsePartitionSnapshot -> {
+                    return addQuorumLeader(responsePartitionSnapshot)
+                        .setErrorCode(Errors.SNAPSHOT_NOT_FOUND.code());
+                }
+            );
+        }
+
+        try (RawSnapshotReader snapshot = snapshotOpt.get()) {
+            int maxSnapshotSize;
+            try {
+                maxSnapshotSize = Math.toIntExact(snapshot.sizeInBytes());
+            } catch (ArithmeticException e) {
+                maxSnapshotSize = Integer.MAX_VALUE;
+            }
+
+            // TODO: Make sure that we also limit based on the fetch max bytes 
configuration
+            ByteBuffer buffer = ByteBuffer.allocate(Math.min(data.maxBytes(), 
maxSnapshotSize));
+            snapshot.read(buffer, partitionSnapshot.position());

Review comment:
       Is it worth validating that `partitionSnapshot.position()` is 
non-negative?

##########
File path: raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
##########
@@ -1629,16 +1872,29 @@ private long pollFollowerAsVoter(FollowerState state, 
long currentTimeMs) throws
             transitionToCandidate(currentTimeMs);
             return 0L;
         } else {
-            long backoffMs = maybeSendRequest(
-                currentTimeMs,
-                state.leaderId(),
-                this::buildFetchRequest
-            );
+            long backoffMs;
+            if (state.fetchingSnapshot().isPresent()) {

Review comment:
       Can we move this logic to `pollFollower` somehow?

##########
File path: raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
##########
@@ -1101,6 +1140,174 @@ private DescribeQuorumResponseData 
handleDescribeQuorumRequest(
         );
     }
 
+    private FetchSnapshotResponseData handleFetchSnapshotRequest(

Review comment:
       Just checking my understanding. This patch adds the logic to respond to 
the snapshot id from a fetch response and to send/handle snapshots when 
needed. However, since it does not contain the logic to set the snapshot id in 
the fetch request handler, none of this logic will get exercised by the 
simulation test. Is that right?

##########
File path: raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
##########
@@ -1101,6 +1140,174 @@ private DescribeQuorumResponseData 
handleDescribeQuorumRequest(
         );
     }
 
+    private FetchSnapshotResponseData handleFetchSnapshotRequest(
+        RaftRequest.Inbound requestMetadata
+    ) throws IOException {
+        FetchSnapshotRequestData data = (FetchSnapshotRequestData) 
requestMetadata.data;
+
+        if (data.topics().size() != 1 && 
data.topics().get(0).partitions().size() != 1) {
+            return FetchSnapshotResponse.withTopError(Errors.INVALID_REQUEST);
+        }
+
+        Optional<FetchSnapshotRequestData.PartitionSnapshot> 
partitionSnapshotOpt = FetchSnapshotRequest
+            .forTopicPartition(data, log.topicPartition());
+        if (!partitionSnapshotOpt.isPresent()) {
+            // The Raft client assumes that there is only one topic partition.
+            TopicPartition unknownTopicPartition = new TopicPartition(
+                data.topics().get(0).name(),
+                data.topics().get(0).partitions().get(0).partition()
+            );
+
+            return FetchSnapshotResponse.singleton(
+                unknownTopicPartition,
+                responsePartitionSnapshot -> {
+                    return responsePartitionSnapshot
+                        
.setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code());
+                }
+            );
+        }
+
+        FetchSnapshotRequestData.PartitionSnapshot partitionSnapshot = 
partitionSnapshotOpt.get();
+        Optional<Errors> leaderValidation = validateLeaderOnlyRequest(
+                partitionSnapshot.currentLeaderEpoch()
+        );
+        if (leaderValidation.isPresent()) {
+            return FetchSnapshotResponse.singleton(
+                log.topicPartition(),
+                responsePartitionSnapshot -> {
+                    return addQuorumLeader(responsePartitionSnapshot)
+                        .setErrorCode(leaderValidation.get().code());
+                }
+            );
+        }
+
+        OffsetAndEpoch snapshotId = new OffsetAndEpoch(
+            partitionSnapshot.snapshotId().endOffset(),
+            partitionSnapshot.snapshotId().epoch()
+        );
+        Optional<RawSnapshotReader> snapshotOpt = log.readSnapshot(snapshotId);
+        if (!snapshotOpt.isPresent()) {
+            return FetchSnapshotResponse.singleton(
+                log.topicPartition(),
+                responsePartitionSnapshot -> {
+                    return addQuorumLeader(responsePartitionSnapshot)
+                        .setErrorCode(Errors.SNAPSHOT_NOT_FOUND.code());
+                }
+            );
+        }
+
+        try (RawSnapshotReader snapshot = snapshotOpt.get()) {
+            int maxSnapshotSize;
+            try {
+                maxSnapshotSize = Math.toIntExact(snapshot.sizeInBytes());
+            } catch (ArithmeticException e) {
+                maxSnapshotSize = Integer.MAX_VALUE;
+            }
+
+            // TODO: Make sure that we also limit based on the fetch max bytes 
configuration
+            ByteBuffer buffer = ByteBuffer.allocate(Math.min(data.maxBytes(), 
maxSnapshotSize));
+            snapshot.read(buffer, partitionSnapshot.position());
+            buffer.flip();
+
+            long snapshotSize = snapshot.sizeInBytes();
+
+            return FetchSnapshotResponse.singleton(
+                log.topicPartition(),
+                responsePartitionSnapshot -> {
+                    addQuorumLeader(responsePartitionSnapshot)
+                        .snapshotId()
+                        .setEndOffset(snapshotId.offset)
+                        .setEpoch(snapshotId.epoch);
+
+                    return responsePartitionSnapshot
+                        .setSize(snapshotSize)
+                        .setPosition(partitionSnapshot.position())
+                        .setBytes(buffer);
+                }
+            );
+        }
+    }
+
+    private boolean handleFetchSnapshotResponse(
+        RaftResponse.Inbound responseMetadata,
+        long currentTimeMs
+    ) throws IOException {
+        FetchSnapshotResponseData data = (FetchSnapshotResponseData) 
responseMetadata.data;
+        Errors topLevelError = Errors.forCode(data.errorCode());
+        if (topLevelError != Errors.NONE) {
+            // TODO: check what values this expression returns
+            return handleTopLevelError(topLevelError, responseMetadata);
+        }
+
+        if (data.topics().size() != 1 && 
data.topics().get(0).partitions().size() != 1) {
+            return false;
+        }
+
+        Optional<FetchSnapshotResponseData.PartitionSnapshot> 
partitionSnapshotOpt = FetchSnapshotResponse
+            .forTopicPartition(data, log.topicPartition());
+        if (!partitionSnapshotOpt.isPresent()) {
+            return false;
+        }
+
+        FetchSnapshotResponseData.PartitionSnapshot partitionSnapshot = 
partitionSnapshotOpt.get();
+
+        FetchSnapshotResponseData.LeaderIdAndEpoch currentLeaderIdAndEpoch = 
partitionSnapshot.currentLeader();
+        OptionalInt responseLeaderId = 
optionalLeaderId(currentLeaderIdAndEpoch.leaderId());
+        int responseEpoch = currentLeaderIdAndEpoch.leaderEpoch();
+        Errors error = Errors.forCode(partitionSnapshot.errorCode());
+
+        Optional<Boolean> handled = maybeHandleCommonResponse(
+            error, responseLeaderId, responseEpoch, currentTimeMs);
+        if (handled.isPresent()) {
+            // TODO: check what values this expression returns
+            return handled.get();
+        }
+
+        FollowerState state = quorum.followerStateOrThrow();
+
+        if (Errors.forCode(partitionSnapshot.errorCode()) == 
Errors.SNAPSHOT_NOT_FOUND ||
+            partitionSnapshot.snapshotId().endOffset() < 0 ||
+            partitionSnapshot.snapshotId().epoch() < 0) {
+
+            /* The leader deleted the snapshot before the follower could 
download it. Start over by

Review comment:
       A log message would probably be helpful. It's probably worth doing one 
full pass over the logic here to see where we could add extra logging.

##########
File path: 
raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
##########
@@ -0,0 +1,765 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.memory.MemoryPool;
+import org.apache.kafka.common.message.FetchResponseData;
+import org.apache.kafka.common.message.FetchSnapshotRequestData;
+import org.apache.kafka.common.message.FetchSnapshotResponseData;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.record.CompressionType;
+import org.apache.kafka.common.requests.FetchSnapshotRequest;
+import org.apache.kafka.common.requests.FetchSnapshotResponse;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.internals.StringSerde;
+import org.apache.kafka.snapshot.RawSnapshotReader;
+import org.apache.kafka.snapshot.RawSnapshotWriter;
+import org.apache.kafka.snapshot.SnapshotWriter;
+import org.apache.kafka.snapshot.SnapshotWriterTest;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Disabled;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+final public class KafkaRaftClientSnapshotTest {
+    @Test
+    public void testMissingFetchSnapshotRequest() throws Exception {
+        int localId = 0;
+        int epoch = 2;
+        Set<Integer> voters = Utils.mkSet(localId, localId + 1);
+
+        RaftClientTestContext context = 
RaftClientTestContext.initializeAsLeader(localId, voters, epoch);
+
+        context.deliverRequest(
+            fetchSnapshotRequest(
+                context.metadataPartition,
+                epoch,
+                new OffsetAndEpoch(0, 0),
+                Integer.MAX_VALUE,
+                0
+            )
+        );
+
+        context.client.poll();
+
+        FetchSnapshotResponseData.PartitionSnapshot response =  
context.assertSentFetchSnapshotResponse(context.metadataPartition).get();
+        assertEquals(Errors.SNAPSHOT_NOT_FOUND, 
Errors.forCode(response.errorCode()));
+    }
+
+    @Test
+    public void testUnknownFetchSnapshotRequest() throws Exception {
+        int localId = 0;
+        Set<Integer> voters = Utils.mkSet(localId, localId + 1);
+        int epoch = 2;
+        TopicPartition topicPartition = new TopicPartition("unknown", 0);
+
+        RaftClientTestContext context = 
RaftClientTestContext.initializeAsLeader(localId, voters, epoch);
+
+        context.deliverRequest(
+            fetchSnapshotRequest(
+                topicPartition,
+                epoch,
+                new OffsetAndEpoch(0, 0),
+                Integer.MAX_VALUE,
+                0
+            )
+        );
+
+        context.client.poll();
+
+        FetchSnapshotResponseData.PartitionSnapshot response =  
context.assertSentFetchSnapshotResponse(topicPartition).get();
+        assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, 
Errors.forCode(response.errorCode()));
+    }
+
+    @Test
+    public void testFetchSnapshotRequestAsLeader() throws Exception {
+        int localId = 0;
+        Set<Integer> voters = Utils.mkSet(localId, localId + 1);
+        int epoch = 2;
+        OffsetAndEpoch snapshotId = new OffsetAndEpoch(0, 0);
+        List<String> records = Arrays.asList("foo", "bar");
+
+        RaftClientTestContext context = 
RaftClientTestContext.initializeAsLeader(localId, voters, epoch);
+
+        try (SnapshotWriter<String> snapshot = 
context.client.createSnapshot(snapshotId)) {
+            snapshot.append(records);
+            snapshot.freeze();
+        }
+
+        try (RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get()) {
+            context.deliverRequest(
+                fetchSnapshotRequest(
+                    context.metadataPartition,
+                    epoch,
+                    snapshotId,
+                    Integer.MAX_VALUE,
+                    0
+                )
+            );
+
+            context.client.poll();
+
+            FetchSnapshotResponseData.PartitionSnapshot response =  context
+                .assertSentFetchSnapshotResponse(context.metadataPartition)
+                .get();
+
+            assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
+            assertEquals(snapshot.sizeInBytes(), response.size());
+            assertEquals(0, response.position());
+            assertEquals(snapshot.sizeInBytes(), response.bytes().remaining());
+
+            ByteBuffer buffer = 
ByteBuffer.allocate(Math.toIntExact(snapshot.sizeInBytes()));
+            snapshot.read(buffer, 0);
+            buffer.flip();
+
+            assertEquals(buffer.slice(), response.bytes());
+        }
+    }
+
+    @Test
+    public void testPartialFetchSnapshotRequestAsLeader() throws Exception {
+        int localId = 0;
+        Set<Integer> voters = Utils.mkSet(localId, localId + 1);
+        int epoch = 2;
+        OffsetAndEpoch snapshotId = new OffsetAndEpoch(0, 0);
+        List<String> records = Arrays.asList("foo", "bar");
+
+        RaftClientTestContext context = 
RaftClientTestContext.initializeAsLeader(localId, voters, epoch);
+
+        try (SnapshotWriter<String> snapshot = 
context.client.createSnapshot(snapshotId)) {
+            snapshot.append(records);
+            snapshot.freeze();
+        }
+
+        try (RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get()) {
+            // Fetch half of the snapshot
+            context.deliverRequest(
+                fetchSnapshotRequest(
+                    context.metadataPartition,
+                    epoch,
+                    snapshotId,
+                    Math.toIntExact(snapshot.sizeInBytes() / 2),
+                    0
+                )
+            );
+
+            context.client.poll();
+
+            FetchSnapshotResponseData.PartitionSnapshot response = context
+                .assertSentFetchSnapshotResponse(context.metadataPartition)
+                .get();
+
+            assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
+            assertEquals(snapshot.sizeInBytes(), response.size());
+            assertEquals(0, response.position());
+            assertEquals(snapshot.sizeInBytes() / 2, 
response.bytes().remaining());
+
+            ByteBuffer snapshotBuffer = 
ByteBuffer.allocate(Math.toIntExact(snapshot.sizeInBytes()));
+            snapshot.read(snapshotBuffer, 0);
+            snapshotBuffer.flip();
+
+            ByteBuffer responseBuffer = 
ByteBuffer.allocate(Math.toIntExact(snapshot.sizeInBytes()));
+            responseBuffer.put(response.bytes());
+
+            ByteBuffer expectedBytes = snapshotBuffer.duplicate();
+            expectedBytes.limit(Math.toIntExact(snapshot.sizeInBytes() / 2));
+
+            assertEquals(expectedBytes, responseBuffer.duplicate().flip());
+
+            // Fetch the remainder of the snapshot
+            context.deliverRequest(
+                fetchSnapshotRequest(
+                    context.metadataPartition,
+                    epoch,
+                    snapshotId,
+                    Integer.MAX_VALUE,
+                    responseBuffer.position()
+                )
+            );
+
+            context.client.poll();
+
+            response = 
context.assertSentFetchSnapshotResponse(context.metadataPartition).get();
+            assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
+            assertEquals(snapshot.sizeInBytes(), response.size());
+            assertEquals(responseBuffer.position(), response.position());
+            assertEquals(snapshot.sizeInBytes() - (snapshot.sizeInBytes() / 
2), response.bytes().remaining());
+
+            responseBuffer.put(response.bytes());
+            assertEquals(snapshotBuffer, responseBuffer.flip());
+        }
+    }
+
+    @Test
+    public void testFetchSnapshotRequestAsFollower() throws IOException {
+        int localId = 0;
+        int leaderId = localId + 1;
+        Set<Integer> voters = Utils.mkSet(localId, leaderId);
+        int epoch = 2;
+        OffsetAndEpoch snapshotId = new OffsetAndEpoch(0, 0);
+
+        RaftClientTestContext context = new 
RaftClientTestContext.Builder(localId, voters)
+            .withElectedLeader(epoch, leaderId)
+            .build();
+
+        context.deliverRequest(
+            fetchSnapshotRequest(
+                context.metadataPartition,
+                epoch,
+                snapshotId,
+                Integer.MAX_VALUE,
+                0
+            )
+        );
+
+        context.client.poll();
+
+        FetchSnapshotResponseData.PartitionSnapshot response =  
context.assertSentFetchSnapshotResponse(context.metadataPartition).get();
+        assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, 
Errors.forCode(response.errorCode()));
+        assertEquals(epoch, response.currentLeader().leaderEpoch());
+        assertEquals(leaderId, response.currentLeader().leaderId());
+    }
+
+    @Disabled
+    @Test
+    public void testFetchSnapshotRequestWithOlderEpoch() throws IOException {
+        assertTrue(false);

Review comment:
       I guess you're still planning to implement these?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to