This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e89dad95 RATIS-589. Eliminate buffer copying in SegmentedRaftLogOutputStream. (#964)
2e89dad95 is described below

commit 2e89dad950e0cc37d6d08c7ecc92a910ad50f518
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Wed Nov 15 18:12:58 2023 -0800

    RATIS-589. Eliminate buffer copying in SegmentedRaftLogOutputStream. (#964)
---
 .../main/java/org/apache/ratis/util/IOUtils.java   |  21 +--
 .../java/org/apache/ratis/util/PureJavaCrc32C.java |  60 ++++++++-
 .../org/apache/ratis/util/TestPureJavaCrc32C.java  |  59 +++++++++
 ratis-docs/src/site/markdown/configurations.md     |   2 +-
 .../apache/ratis/server/RaftServerConfigKeys.java  |   2 +-
 .../raftlog/segmented/BufferedWriteChannel.java    |  94 +++++++++----
 .../raftlog/segmented/SegmentedRaftLogFormat.java  |  28 ++--
 .../segmented/SegmentedRaftLogOutputStream.java    |  65 ++++-----
 .../raftlog/segmented/SegmentedRaftLogReader.java  |   5 +-
 .../raftlog/segmented/SegmentedRaftLogWorker.java  |  15 ++-
 .../apache/ratis/server/ServerRestartTests.java    |  15 +--
 .../segmented/TestBufferedWriteChannel.java        | 147 ++++++++++++++++-----
 .../raftlog/segmented/TestSegmentedRaftLog.java    |  19 ---
 13 files changed, 365 insertions(+), 167 deletions(-)

diff --git a/ratis-common/src/main/java/org/apache/ratis/util/IOUtils.java b/ratis-common/src/main/java/org/apache/ratis/util/IOUtils.java
index 0153ac491..f1fe6c35c 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/IOUtils.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/IOUtils.java
@@ -126,22 +126,6 @@ public interface IOUtils {
     }
   }
 
-  /**
-   * Write a ByteBuffer to a FileChannel at a given offset,
-   * handling short writes.
-   *
-   * @param fc               The FileChannel to write to
-   * @param buf              The input buffer
-   * @param offset           The offset in the file to start writing at
-   * @throws IOException     On I/O error
-   */
-  static void writeFully(FileChannel fc, ByteBuffer buf, long offset)
-      throws IOException {
-    do {
-      offset += fc.write(buf, offset);
-    } while (buf.remaining() > 0);
-  }
-
   static long preallocate(FileChannel fc, long size, ByteBuffer fill) throws 
IOException {
     Preconditions.assertSame(0, fill.position(), "fill.position");
     Preconditions.assertSame(fill.capacity(), fill.limit(), "fill.limit");
@@ -153,8 +137,9 @@ public interface IOUtils {
       final int n = remaining < required? remaining: Math.toIntExact(required);
       final ByteBuffer buffer = fill.slice();
       buffer.limit(n);
-      IOUtils.writeFully(fc, buffer, fc.size());
-      allocated += n;
+
+      final int written = fc.write(buffer, fc.size());
+      allocated += written;
     }
     return allocated;
   }
diff --git a/ratis-common/src/main/java/org/apache/ratis/util/PureJavaCrc32C.java b/ratis-common/src/main/java/org/apache/ratis/util/PureJavaCrc32C.java
index ef01f0474..bdb000290 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/PureJavaCrc32C.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/PureJavaCrc32C.java
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -22,6 +22,8 @@ package org.apache.ratis.util;
 
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.util.zip.Checksum;
 
 /**
@@ -91,6 +93,62 @@ public class PureJavaCrc32C implements Checksum {
     crc = localCrc;
   }
 
+  @SuppressFBWarnings("SF_SWITCH_NO_DEFAULT")
+  public void update(ByteBuffer b) {
+    int localCrc = crc;
+
+    b.order(ByteOrder.LITTLE_ENDIAN);
+    int off = b.position();
+    int len = b.remaining();
+    while(len > 7) {
+      final long value = b.getLong(off);
+      final int m = (int) value;
+      final int n = (int)(value >> 4*8);
+      final int c0 =((m >> 0*8) ^ (localCrc >> 0*8)) & 0xff;
+      final int c1 =((m >> 1*8) ^ (localCrc >> 1*8)) & 0xff;
+      final int c2 =((m >> 2*8) ^ (localCrc >> 2*8)) & 0xff;
+      final int c3 =((m >> 3*8) ^ (localCrc >> 3*8)) & 0xff;
+      final int c4 = (n >> 0*8) & 0xff;
+      final int c5 = (n >> 1*8) & 0xff;
+      final int c6 = (n >> 2*8) & 0xff;
+      final int c7 = (n >> 3*8) & 0xff;
+
+      localCrc = (T[T8_7_START + c0] ^ T[T8_6_START + c1])
+          ^ (T[T8_5_START + c2] ^ T[T8_4_START + c3])
+          ^ (T[T8_3_START + c4] ^ T[T8_2_START + c5])
+          ^ (T[T8_1_START + c6] ^ T[T8_0_START + c7]);
+
+      off += 8;
+      len -= 8;
+    }
+
+    if (len > 3) {
+      final int n = b.getInt(off);
+      localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ (n >> 0*8)) & 
0xff)];
+      localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ (n >> 1*8)) & 
0xff)];
+      localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ (n >> 2*8)) & 
0xff)];
+      localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ (n >> 3*8)) & 
0xff)];
+
+      off += 4;
+      len -= 4;
+    }
+
+    if (len > 1) {
+      final int n = b.getShort(off);
+      localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ (n >> 0*8)) & 
0xff)];
+      localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ (n >> 1*8)) & 
0xff)];
+
+      off += 2;
+      len -= 2;
+    }
+
+    if (len > 0) {
+      localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b.get(off)) & 
0xff)];
+    }
+    // Publish crc out to object
+    crc = localCrc;
+  }
+
   @Override
   public final void update(int b) {
     crc = (crc >>> 8) ^ T[T8_0_START + ((crc ^ b) & 0xff)];
diff --git a/ratis-common/src/test/java/org/apache/ratis/util/TestPureJavaCrc32C.java b/ratis-common/src/test/java/org/apache/ratis/util/TestPureJavaCrc32C.java
new file mode 100644
index 000000000..5a695fd84
--- /dev/null
+++ b/ratis-common/src/test/java/org/apache/ratis/util/TestPureJavaCrc32C.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.util;
+
+import org.apache.ratis.BaseTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+/** Testing {@link PureJavaCrc32C}. */
+public class TestPureJavaCrc32C extends BaseTest {
+  static final ThreadLocalRandom RANDOM = ThreadLocalRandom.current();
+
+  @Test
+  public void testByteBuffer() {
+    for(int length = 1; length < 1 << 20; length <<= 2) {
+      runTestByteBuffer(length - 1);
+      runTestByteBuffer(length);
+      runTestByteBuffer(length + 1);
+    }
+  }
+
+  /** Verify if the CRC computed by {@link ByteBuffer}s is the same as the CRC 
computed by arrays. */
+  static void runTestByteBuffer(int length) {
+    final byte[] array = new byte[length];
+    RANDOM.nextBytes(array);
+    final ByteBuffer buffer = ByteBuffer.wrap(array);
+
+    final PureJavaCrc32C arrayCrc = new PureJavaCrc32C();
+    final PureJavaCrc32C bufferCrc = new PureJavaCrc32C();
+    for (int off = 0; off < array.length; ) {
+      final int len = RANDOM.nextInt(array.length - off) + 1;
+      arrayCrc.update(array, off, len);
+
+      buffer.position(off).limit(off + len);
+      bufferCrc.update(buffer);
+
+      Assert.assertEquals(arrayCrc.getValue(), bufferCrc.getValue());
+      off += len;
+    }
+  }
+}
diff --git a/ratis-docs/src/site/markdown/configurations.md b/ratis-docs/src/site/markdown/configurations.md
index dd953e7fd..7c8fb001f 100644
--- a/ratis-docs/src/site/markdown/configurations.md
+++ b/ratis-docs/src/site/markdown/configurations.md
@@ -342,7 +342,7 @@ Ratis will temporarily stall the new IO Tasks.
 
|:----------------|:------------------------------------------------------------|
 | **Description** | size of direct byte buffer for SegmentedRaftLog 
FileChannel |
 | **Type**        | SizeInBytes                                                
 |
-| **Default**     | 64KB                                                       
 |
+| **Default**     | 8MB                                                        
 |
 
 | **Property**    | `raft.server.log.force.sync.num`                           
                     |
 
|:----------------|:--------------------------------------------------------------------------------|
diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/RaftServerConfigKeys.java b/ratis-server-api/src/main/java/org/apache/ratis/server/RaftServerConfigKeys.java
index f2de074a8..565e88126 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/server/RaftServerConfigKeys.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/server/RaftServerConfigKeys.java
@@ -424,7 +424,7 @@ public interface RaftServerConfigKeys {
     }
 
     String WRITE_BUFFER_SIZE_KEY = PREFIX + ".write.buffer.size";
-    SizeInBytes WRITE_BUFFER_SIZE_DEFAULT =SizeInBytes.valueOf("64KB");
+    SizeInBytes WRITE_BUFFER_SIZE_DEFAULT = SizeInBytes.valueOf("8MB");
     static SizeInBytes writeBufferSize(RaftProperties properties) {
       return getSizeInBytes(properties::getSizeInBytes,
           WRITE_BUFFER_SIZE_KEY, WRITE_BUFFER_SIZE_DEFAULT, getDefaultLog());
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/BufferedWriteChannel.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/BufferedWriteChannel.java
index fd06a2b37..7ad1a48ee 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/BufferedWriteChannel.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/BufferedWriteChannel.java
@@ -20,6 +20,9 @@ package org.apache.ratis.server.raftlog.segmented;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import org.apache.ratis.util.Preconditions;
 import org.apache.ratis.util.function.CheckedBiFunction;
+import org.apache.ratis.util.function.CheckedConsumer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
 import java.io.File;
@@ -34,24 +37,33 @@ import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * Provides a buffering layer in front of a FileChannel for writing.
- *
+ * <p>
  * This class is NOT threadsafe.
  */
 class BufferedWriteChannel implements Closeable {
+  static final Logger LOG = 
LoggerFactory.getLogger(BufferedWriteChannel.class);
 
   @SuppressWarnings("java:S2095") // return Closable
   static BufferedWriteChannel open(File file, boolean append, ByteBuffer 
buffer) throws IOException {
     final RandomAccessFile raf = new RandomAccessFile(file, "rw");
     final FileChannel fc = raf.getChannel();
+    final long size = file.length(); // 0L if the file does not exist.
     if (append) {
-      fc.position(fc.size());
+      fc.position(size);
+      Preconditions.assertSame(size, fc.size(), "fc.size");
     } else {
-      fc.truncate(0);
+      if (size > 0) {
+        fc.truncate(0);
+      }
+      Preconditions.assertSame(0, fc.size(), "fc.size");
     }
     Preconditions.assertSame(fc.size(), fc.position(), "fc.position");
-    return new BufferedWriteChannel(fc, buffer);
+    final String name = file.getName() + (append? " (append)": "");
+    LOG.info("open {} at position {}", name, fc.position());
+    return new BufferedWriteChannel(name, fc, buffer);
   }
 
+  private final String name;
   private final FileChannel fileChannel;
   private final ByteBuffer writeBuffer;
   private boolean forced = true;
@@ -59,29 +71,57 @@ class BufferedWriteChannel implements Closeable {
       = new AtomicReference<>(CompletableFuture.completedFuture(null));
 
 
-  BufferedWriteChannel(FileChannel fileChannel, ByteBuffer byteBuffer) {
+  BufferedWriteChannel(String name, FileChannel fileChannel, ByteBuffer 
byteBuffer) {
+    this.name = name;
     this.fileChannel = fileChannel;
     this.writeBuffer = byteBuffer;
   }
 
-  void write(byte[] b) throws IOException {
-    write(b, b.length);
+  int writeBufferPosition() {
+    return writeBuffer.position();
   }
-  void write(byte[] b, int len) throws IOException {
-    int offset = 0;
-    while (offset < len) {
-      int toPut = Math.min(len - offset, writeBuffer.remaining());
-      writeBuffer.put(b, offset, toPut);
-      offset += toPut;
-      if (writeBuffer.remaining() == 0) {
-        flushBuffer();
-      }
+
+  /**
+   * Write to buffer.
+   *
+   * @param writeSize the size to write.
+   * @param writeMethod write exactly the writeSize of bytes to the buffer and 
advance buffer position.
+   */
+  void writeToBuffer(int writeSize, CheckedConsumer<ByteBuffer, IOException> 
writeMethod) throws IOException {
+    if (writeSize > writeBuffer.capacity()) {
+      throw new IOException("writeSize = " + writeSize
+          + " > writeBuffer.capacity() = " + writeBuffer.capacity());
+    }
+    if (writeSize > writeBuffer.remaining()) {
+      flushBuffer();
     }
+    final int pos = writeBufferPosition();
+    final int lim = writeBuffer.limit();
+    writeMethod.accept(writeBuffer);
+    final int written = writeBufferPosition() - pos;
+    Preconditions.assertSame(writeSize, written, "written");
+    Preconditions.assertSame(lim, writeBuffer.limit(), "writeBuffer.limit()");
+  }
+
+  /** Write the content of the given buffer to {@link #fileChannel}. */
+  void writeToChannel(ByteBuffer buffer) throws IOException {
+    Preconditions.assertSame(0, writeBufferPosition(), 
"writeBuffer.position()");
+    final int length = buffer.remaining();
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Write {} bytes (pos={}, size={}) to channel {}",
+          length, fileChannel.position(), fileChannel.size(), this);
+    }
+    int written = 0;
+    for(; written < length; ) {
+      written += fileChannel.write(buffer);
+    }
+    Preconditions.assertSame(length, written, "written");
+    forced = false;
   }
 
   void preallocateIfNecessary(long size, CheckedBiFunction<FileChannel, Long, 
Long, IOException> preallocate)
       throws IOException {
-    final long outstanding = writeBuffer.position() + size;
+    final long outstanding = writeBufferPosition() + size;
     if (fileChannel.position() + outstanding > fileChannel.size()) {
       preallocate.apply(fileChannel, outstanding);
     }
@@ -115,26 +155,19 @@ class BufferedWriteChannel implements Closeable {
     try {
       fileChannel.force(false);
     } catch (IOException e) {
-      LogSegment.LOG.error("Failed to flush channel", e);
-      throw new CompletionException(e);
+      throw new CompletionException("Failed to force channel " + this, e);
     }
     return null;
   }
 
-  /**
-   * Write any data in the buffer to the file.
-   *
-   * @throws IOException if the write fails.
-   */
+  /** Flush the data from the {@link #writeBuffer} to {@link #fileChannel}. */
   private void flushBuffer() throws IOException {
-    if (writeBuffer.position() == 0) {
+    if (writeBufferPosition() == 0) {
       return; // nothing to flush
     }
 
     writeBuffer.flip();
-    do {
-      fileChannel.write(writeBuffer);
-    } while (writeBuffer.hasRemaining());
+    writeToChannel(writeBuffer);
     writeBuffer.clear();
     forced = false;
   }
@@ -157,4 +190,9 @@ class BufferedWriteChannel implements Closeable {
       fileChannel.close();
     }
   }
+
+  @Override
+  public String toString() {
+    return name;
+  }
 }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogFormat.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogFormat.java
index 57c6e8874..d55f07916 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogFormat.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogFormat.java
@@ -18,43 +18,41 @@
 package org.apache.ratis.server.raftlog.segmented;
 
 import org.apache.ratis.util.Preconditions;
-import org.apache.ratis.util.function.CheckedFunction;
 
-import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
 
 public interface SegmentedRaftLogFormat {
   class Internal {
-    private static final byte[] HEADER_BYTES = 
"RaftLog1".getBytes(StandardCharsets.UTF_8);
-    private static final byte[] HEADER_BYTES_CLONE = HEADER_BYTES.clone();
+    private static final ByteBuffer HEADER;
     private static final byte TERMINATOR_BYTE = 0;
 
-    private static void assertHeader() {
-      Preconditions.assertTrue(Arrays.equals(HEADER_BYTES, 
HEADER_BYTES_CLONE));
+    static {
+      final byte[] bytes = "RaftLog1".getBytes(StandardCharsets.UTF_8);
+      final ByteBuffer header = ByteBuffer.allocateDirect(bytes.length);
+      header.put(bytes).flip();
+      HEADER = header.asReadOnlyBuffer();
     }
   }
 
   static int getHeaderLength() {
-    return Internal.HEADER_BYTES.length;
+    return Internal.HEADER.remaining();
+  }
+
+  static ByteBuffer getHeaderBytebuffer() {
+    return Internal.HEADER.duplicate();
   }
 
   static int matchHeader(byte[] bytes, int offset, int length) {
     Preconditions.assertTrue(length <= getHeaderLength());
     for(int i = 0; i < length; i++) {
-      if (bytes[offset + i] != Internal.HEADER_BYTES[i]) {
+      if (bytes[offset + i] != Internal.HEADER.get(i)) {
         return i;
       }
     }
     return length;
   }
 
-  static <T> T applyHeaderTo(CheckedFunction<byte[], T, IOException> function) 
throws IOException {
-    final T t = function.apply(Internal.HEADER_BYTES);
-    Internal.assertHeader(); // assert that the header is unmodified by the 
function.
-    return t;
-  }
-
   static byte getTerminator() {
     return Internal.TERMINATOR_BYTE;
   }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogOutputStream.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogOutputStream.java
index e0fd41fbd..ba564f505 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogOutputStream.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogOutputStream.java
@@ -23,7 +23,6 @@ import org.apache.ratis.util.IOUtils;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Preconditions;
 import org.apache.ratis.util.PureJavaCrc32C;
-import org.apache.ratis.util.function.CheckedConsumer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -34,8 +33,6 @@ import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
-import java.util.function.Supplier;
-import java.util.zip.Checksum;
 
 public class SegmentedRaftLogOutputStream implements Closeable {
   private static final Logger LOG = 
LoggerFactory.getLogger(SegmentedRaftLogOutputStream.class);
@@ -43,17 +40,17 @@ public class SegmentedRaftLogOutputStream implements Closeable {
   private static final ByteBuffer FILL;
   private static final int BUFFER_SIZE = 1024 * 1024; // 1 MB
   static {
-    FILL = ByteBuffer.allocateDirect(BUFFER_SIZE);
-    for (int i = 0; i < FILL.capacity(); i++) {
-      FILL.put(SegmentedRaftLogFormat.getTerminator());
+    final ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_SIZE);
+    for (int i = 0; i < BUFFER_SIZE; i++) {
+      buffer.put(SegmentedRaftLogFormat.getTerminator());
     }
-    FILL.flip();
+    buffer.flip();
+    FILL = buffer.asReadOnlyBuffer();
   }
 
-  private final File file;
+  private final String name;
   private final BufferedWriteChannel out; // buffered FileChannel for writing
-  private final Checksum checksum;
-  private final Supplier<byte[]> sharedBuffer;
+  private final PureJavaCrc32C checksum = new PureJavaCrc32C();
 
   private final long segmentMaxSize;
   private final long preallocatedSize;
@@ -61,23 +58,15 @@ public class SegmentedRaftLogOutputStream implements Closeable {
   public SegmentedRaftLogOutputStream(File file, boolean append, long 
segmentMaxSize,
       long preallocatedSize, ByteBuffer byteBuffer)
       throws IOException {
-    this(file, append, segmentMaxSize, preallocatedSize, byteBuffer, null);
-  }
-
-  SegmentedRaftLogOutputStream(File file, boolean append, long segmentMaxSize,
-      long preallocatedSize, ByteBuffer byteBuffer, Supplier<byte[]> 
sharedBuffer)
-      throws IOException {
-    this.file = file;
-    this.checksum = new PureJavaCrc32C();
+    this.name = JavaUtils.getClassSimpleName(getClass()) + "(" + 
file.getName() + ")";
     this.segmentMaxSize = segmentMaxSize;
     this.preallocatedSize = preallocatedSize;
-    this.sharedBuffer = sharedBuffer;
     this.out = BufferedWriteChannel.open(file, append, byteBuffer);
 
     if (!append) {
       // write header
       preallocateIfNecessary(SegmentedRaftLogFormat.getHeaderLength());
-      
SegmentedRaftLogFormat.applyHeaderTo(CheckedConsumer.asCheckedFunction(out::write));
+      out.writeToChannel(SegmentedRaftLogFormat.getHeaderBytebuffer());
       out.flush();
     }
   }
@@ -98,19 +87,26 @@ public class SegmentedRaftLogOutputStream implements Closeable {
     final int serialized = entry.getSerializedSize();
     final int proto = CodedOutputStream.computeUInt32SizeNoTag(serialized) + 
serialized;
     final int total = proto + 4; // proto and 4-byte checksum
-    final byte[] buf = sharedBuffer != null? sharedBuffer.get(): new 
byte[total];
-    Preconditions.assertTrue(total <= buf.length, () -> "total = " + total + " 
> buf.length " + buf.length);
     preallocateIfNecessary(total);
 
-    CodedOutputStream cout = CodedOutputStream.newInstance(buf);
-    cout.writeUInt32NoTag(serialized);
-    entry.writeTo(cout);
+    out.writeToBuffer(total, buf -> {
+      final int pos = buf.position();
+      final int protoEndPos= pos + proto;
+
+      final CodedOutputStream encoder = CodedOutputStream.newInstance(buf);
+      encoder.writeUInt32NoTag(serialized);
+      entry.writeTo(encoder);
 
-    checksum.reset();
-    checksum.update(buf, 0, proto);
-    ByteBuffer.wrap(buf, proto, 4).putInt((int) checksum.getValue());
+      // compute checksum
+      final ByteBuffer duplicated = buf.duplicate();
+      duplicated.position(pos).limit(protoEndPos);
+      checksum.reset();
+      checksum.update(duplicated);
 
-    out.write(buf, total);
+      buf.position(protoEndPos);
+      buf.putInt((int) checksum.getValue());
+      Preconditions.assertSame(pos + total, buf.position(), "buf.position()");
+    });
   }
 
   @Override
@@ -153,10 +149,15 @@ public class SegmentedRaftLogOutputStream implements Closeable {
   }
 
   private long preallocate(FileChannel fc, long outstanding) throws 
IOException {
-    final long actual = actualPreallocateSize(outstanding, segmentMaxSize - 
fc.size(), preallocatedSize);
+    final long size = fc.size();
+    final long actual = actualPreallocateSize(outstanding, segmentMaxSize - 
size, preallocatedSize);
     Preconditions.assertTrue(actual >= outstanding);
+    final long pos = fc.position();
+    LOG.debug("Preallocate {} bytes (pos={}, size={}) for {}", actual, pos, 
size, this);
     final long allocated = IOUtils.preallocate(fc, actual, FILL);
-    LOG.debug("Pre-allocated {} bytes for {}", allocated, this);
+    Preconditions.assertSame(pos, fc.position(), "fc.position()");
+    Preconditions.assertSame(actual, allocated, "allocated");
+    Preconditions.assertSame(size + allocated, fc.size(), "fc.size()");
     return allocated;
   }
 
@@ -166,6 +167,6 @@ public class SegmentedRaftLogOutputStream implements Closeable {
 
   @Override
   public String toString() {
-    return JavaUtils.getClassSimpleName(getClass()) + "(" + file + ")";
+    return name;
   }
 }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogReader.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogReader.java
index c8170f9b1..7d03105b9 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogReader.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogReader.java
@@ -185,7 +185,7 @@ class SegmentedRaftLogReader implements Closeable {
     throw new CorruptedFileException(file, "Log header mismatched: expected 
header length="
         + SegmentedRaftLogFormat.getHeaderLength() + ", read length=" + 
readLength + ", match length=" + matchLength
         + ", header in file=" + StringUtils.bytes2HexString(temp, 0, 
readLength)
-        + ", expected header=" + 
SegmentedRaftLogFormat.applyHeaderTo(StringUtils::bytes2HexString));
+        + ", expected header=" + 
StringUtils.bytes2HexString(SegmentedRaftLogFormat.getHeaderBytebuffer()));
   }
 
   /**
@@ -245,7 +245,8 @@ class SegmentedRaftLogReader implements Closeable {
         }
         for (idx = 0; idx < numRead; idx++) {
           if (!SegmentedRaftLogFormat.isTerminator(temp[idx])) {
-            throw new IOException("Read extra bytes after the terminator!");
+            throw new IOException("Read extra bytes after the terminator at 
position "
+                + (limiter.getPos() - numRead + idx) + " in " + file);
           }
         }
       } finally {
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
index 18fd68012..0e8d0f3b7 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
@@ -50,7 +50,6 @@ import java.util.Objects;
 import java.util.Optional;
 import java.util.Queue;
 import java.util.concurrent.*;
-import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
@@ -149,7 +148,6 @@ class SegmentedRaftLogWorker {
   private final StateMachine stateMachine;
   private final SegmentedRaftLogMetrics raftLogMetrics;
   private final ByteBuffer writeBuffer;
-  private final AtomicReference<byte[]> sharedBuffer;
 
   /**
    * The number of entries that have been written into the 
SegmentedRaftLogOutputStream but
@@ -210,7 +208,12 @@ class SegmentedRaftLogWorker {
     this.writeBuffer = ByteBuffer.allocateDirect(bufferSize);
     final int logEntryLimit = 
RaftServerConfigKeys.Log.Appender.bufferByteLimit(properties).getSizeInt();
     // 4 bytes (serialized size) + logEntryLimit + 4 bytes (checksum)
-    this.sharedBuffer = new AtomicReference<>(new byte[logEntryLimit + 8]);
+    if (bufferSize < logEntryLimit + 8) {
+      throw new 
IllegalArgumentException(RaftServerConfigKeys.Log.WRITE_BUFFER_SIZE_KEY
+          + " (= " + bufferSize
+          + ") is less than " + 
RaftServerConfigKeys.Log.Appender.BUFFER_BYTE_LIMIT_KEY
+          + " + 8 (= " + (logEntryLimit + 8) + ")");
+    }
     this.unsafeFlush = RaftServerConfigKeys.Log.unsafeFlushEnabled(properties);
     this.asyncFlush = RaftServerConfigKeys.Log.asyncFlushEnabled(properties);
     if (asyncFlush && unsafeFlush) {
@@ -235,7 +238,6 @@ class SegmentedRaftLogWorker {
 
   void close() {
     this.running = false;
-    sharedBuffer.set(null);
     Optional.ofNullable(flushExecutor).ifPresent(ExecutorService::shutdown);
     ConcurrentUtils.shutdownAndWait(TimeDuration.ONE_SECOND.multiply(3),
         workerThreadExecutor, timeout -> LOG.warn("{}: shutdown timeout in " + 
timeout, name));
@@ -730,8 +732,9 @@ class SegmentedRaftLogWorker {
   }
 
   private void allocateSegmentedRaftLogOutputStream(File file, boolean append) 
throws IOException {
-    Preconditions.assertTrue(out == null && writeBuffer.position() == 0);
+    Preconditions.assertNull(out, "out");
+    Preconditions.assertSame(0, writeBuffer.position(), 
"writeBuffer.position()");
     out = new SegmentedRaftLogOutputStream(file, append, segmentMaxSize,
-        preallocatedSize, writeBuffer, sharedBuffer::get);
+        preallocatedSize, writeBuffer);
   }
 }
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/ServerRestartTests.java b/ratis-test/src/test/java/org/apache/ratis/server/ServerRestartTests.java
index db5cdc34c..73ff1eb53 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/ServerRestartTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/ServerRestartTests.java
@@ -54,6 +54,7 @@ import org.slf4j.event.Level;
 import java.io.File;
 import java.io.IOException;
 import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
 import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.List;
@@ -218,14 +219,12 @@ public abstract class ServerRestartTests<CLUSTER extends MiniRaftCluster>
       MiniRaftCluster cluster, Logger LOG) throws Exception {
     Preconditions.assertTrue(partialLength < 
SegmentedRaftLogFormat.getHeaderLength());
     try(final RandomAccessFile raf = new RandomAccessFile(openLogFile, "rw")) {
-      SegmentedRaftLogFormat.applyHeaderTo(header -> {
-        LOG.info("header    = {}", StringUtils.bytes2HexString(header));
-        final byte[] corrupted = new byte[header.length];
-        System.arraycopy(header, 0, corrupted, 0, partialLength);
-        LOG.info("corrupted = {}", StringUtils.bytes2HexString(corrupted));
-        raf.write(corrupted);
-        return null;
-      });
+      final ByteBuffer header = SegmentedRaftLogFormat.getHeaderBytebuffer();
+      LOG.info("header    = {}", StringUtils.bytes2HexString(header));
+      final byte[] corrupted = new byte[header.remaining()];
+      header.get(corrupted, 0, partialLength);
+      LOG.info("corrupted = {}", StringUtils.bytes2HexString(corrupted));
+      raf.write(corrupted);
     }
     final RaftServer.Division server = cluster.restartServer(id, false);
     server.getRaftServer().close();
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestBufferedWriteChannel.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestBufferedWriteChannel.java
index a7bb7000a..c9d792855 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestBufferedWriteChannel.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestBufferedWriteChannel.java
@@ -18,15 +18,18 @@
 package org.apache.ratis.server.raftlog.segmented;
 
 import org.apache.ratis.BaseTest;
+import org.apache.ratis.util.StringUtils;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.MappedByteBuffer;
 import java.nio.channels.FileChannel;
 import java.nio.channels.FileLock;
 import java.nio.channels.ReadableByteChannel;
 import java.nio.channels.WritableByteChannel;
+import java.util.concurrent.atomic.AtomicInteger;
 
 /**
  * Test {@link BufferedWriteChannel}
@@ -71,13 +74,12 @@ public class TestBufferedWriteChannel extends BaseTest {
       final int remaining = src.remaining();
       LOG.info("write {} bytes", remaining);
       position += remaining;
-      src.position(src.limit());
       return remaining;
     }
 
     @Override
     public long position() {
-      throw new UnsupportedOperationException();
+      return position;
     }
 
     @Override
@@ -132,55 +134,128 @@ public class TestBufferedWriteChannel extends BaseTest {
     }
   }
 
+  static ByteBuffer allocateByteBuffer(int size) {
+    final ByteBuffer buffer = ByteBuffer.allocate(size);
+    for(int i = 0; i < size; i++) {
+      buffer.put(i, (byte)i);
+    }
+    return buffer.asReadOnlyBuffer();
+  }
+
+
   @Test
-  public void testFlush() throws Exception {
-    final byte[] bytes = new byte[10];
-    final ByteBuffer buffer = ByteBuffer.wrap(bytes);
+  public void testWriteToChannel() throws Exception {
+    for(int n = 1; n < 1 << 20; n <<=2) {
+      runTestWriteToChannel(n - 1);
+      runTestWriteToChannel(n);
+      runTestWriteToChannel(n + 1);
+    }
+  }
+
+  void runTestWriteToChannel(int bufferSize) throws Exception {
+    final ByteBuffer buffer = allocateByteBuffer(2 * bufferSize);
     final FakeFileChannel fake = new FakeFileChannel();
-    final BufferedWriteChannel out = new BufferedWriteChannel(fake, buffer);
+    final BufferedWriteChannel out = new BufferedWriteChannel("test", fake, ByteBuffer.allocate(0));
 
-    // write exactly buffer size, then flush.
     fake.assertValues(0, 0);
-    out.write(bytes);
-    int pos = bytes.length;
-    fake.assertValues(pos,  0);
-    out.flush();
-    int force = pos;
-    fake.assertValues(pos, force);
+    final AtomicInteger pos = new AtomicInteger();
+    final AtomicInteger force = new AtomicInteger();
+
+    {
+      // write exactly buffer size, then flush.
+      writeToChannel(out, fake, pos, force, buffer, bufferSize);
+      flush(out, fake, pos, force);
+    }
 
     {
       // write less than buffer size, then flush.
-      int n = bytes.length/2;
-      out.write(new byte[n]);
-      fake.assertValues(pos, force);
-      out.flush();
-      pos += n;
-      force = pos;
-      fake.assertValues(pos, force);
+      writeToChannel(out, fake, pos, force, buffer, bufferSize/2);
+      flush(out, fake, pos, force);
     }
 
     {
       // write less than buffer size twice, then flush.
-      int n = bytes.length*2/3;
-      out.write(new byte[n]);
-      fake.assertValues(pos, force);
-      out.write(new byte[n]);
-      fake.assertValues(pos + bytes.length, force);
-      out.flush();
-      pos += 2*n;
-      force = pos;
-      fake.assertValues(pos, force);
+      final int n = bufferSize*2/3;
+      writeToChannel(out, fake, pos, force, buffer, n);
+      writeToChannel(out, fake, pos, force, buffer, n);
+      flush(out, fake, pos, force);
     }
 
     {
       // write more than buffer size, then flush.
-      int n = bytes.length*3/2;
-      out.write(new byte[n]);
-      fake.assertValues(pos + bytes.length, force);
-      out.flush();
-      pos += n;
-      force = pos;
-      fake.assertValues(pos, force);
+      writeToChannel(out, fake, pos, force, buffer, bufferSize*3/2);
+      flush(out, fake, pos, force);
+    }
+  }
+
+  static void writeToChannel(BufferedWriteChannel out, FakeFileChannel fake, AtomicInteger pos, AtomicInteger force,
+      ByteBuffer buffer, int n) throws IOException {
+    buffer.position(0).limit(n);
+    out.writeToChannel(buffer);
+    pos.addAndGet(n);
+    fake.assertValues(pos.get(), force.get());
+  }
+
+  static void flush(BufferedWriteChannel out, FakeFileChannel fake,
+      AtomicInteger pos, AtomicInteger force) throws IOException {
+    final int existing = out.writeBufferPosition();
+    out.flush();
+    Assert.assertEquals(0, out.writeBufferPosition());
+    pos.addAndGet(existing);
+    force.set(pos.get());
+    fake.assertValues(pos.get(), force.get());
+  }
+
+  static void writeToBuffer(BufferedWriteChannel out, FakeFileChannel fake, AtomicInteger pos, AtomicInteger force,
+      int bufferCapacity, ByteBuffer buffer, int n) throws IOException {
+    final int existing = out.writeBufferPosition();
+    buffer.position(0).limit(n);
+    out.writeToBuffer(n, b -> b.put(buffer));
+    if (existing + n > bufferCapacity) {
+      pos.addAndGet(existing);
+      Assert.assertEquals(n, out.writeBufferPosition());
+    } else {
+      Assert.assertEquals(existing + n, out.writeBufferPosition());
+    }
+    fake.assertValues(pos.get(), force.get());
+  }
+
+  @Test
+  public void testWriteToBuffer() throws Exception {
+    for(int n = 1; n < 1 << 20; n <<=2) {
+      runTestWriteToBuffer(n - 1);
+      runTestWriteToBuffer(n);
+      runTestWriteToBuffer(n + 1);
+    }
+  }
+
+  void runTestWriteToBuffer(int bufferSize) throws Exception {
+    final ByteBuffer buffer = allocateByteBuffer(2 * bufferSize);
+    final FakeFileChannel fake = new FakeFileChannel();
+    final BufferedWriteChannel out = new BufferedWriteChannel("test", fake, ByteBuffer.allocate(bufferSize));
+
+    fake.assertValues(0, 0);
+    final AtomicInteger pos = new AtomicInteger();
+    final AtomicInteger force = new AtomicInteger();
+
+    {
+      // write exactly buffer size, then flush.
+      writeToBuffer(out, fake, pos, force, bufferSize, buffer, bufferSize);
+      flush(out, fake, pos, force);
+    }
+
+    {
+      // write less than buffer size, then flush.
+      writeToBuffer(out, fake, pos, force, bufferSize, buffer, bufferSize/2);
+      flush(out, fake, pos, force);
+    }
+
+    {
+      // write less than buffer size twice, then flush.
+      final int n = bufferSize*2/3;
+      writeToBuffer(out, fake, pos, force, bufferSize, buffer, n);
+      writeToBuffer(out, fake, pos, force, bufferSize, buffer, n);
+      flush(out, fake, pos, force);
     }
   }
 }
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
index abc36e4ee..38fa45e6f 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
@@ -698,25 +698,6 @@ public class TestSegmentedRaftLog extends BaseTest {
         10, HUNDRED_MILLIS, "assertIndices", LOG);
   }
 
-  @Test
-  public void testSegmentedRaftLogFormatInternalHeader() throws Exception {
-    testFailureCase("testSegmentedRaftLogFormatInternalHeader",
-        () -> SegmentedRaftLogFormat.applyHeaderTo(header -> {
-          LOG.info("header  = " + new String(header, StandardCharsets.UTF_8));
-          header[0]++; // try changing the internal header
-          LOG.info("header' = " + new String(header, StandardCharsets.UTF_8));
-          return null;
-        }), IllegalStateException.class);
-
-    // reset the header
-    SegmentedRaftLogFormat.applyHeaderTo(header -> {
-      LOG.info("header'  = " + new String(header, StandardCharsets.UTF_8));
-      header[0] -= 1; // try changing the internal header
-      LOG.info("header'' = " + new String(header, StandardCharsets.UTF_8));
-      return null;
-    });
-  }
-
   @Test
   public void testAsyncFlushPerf1() throws Exception {
     List<SegmentRange> ranges = prepareRanges(0, 50, 20000, 0);


Reply via email to