You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this rendering.
Posted to commits@ratis.apache.org by ru...@apache.org on 2020/11/21 02:16:07 UTC

[incubator-ratis] branch master updated: RATIS-1168. Move the utils from DataStreamBaseTest to a new file. (#292)

This is an automated email from the ASF dual-hosted git repository.

runzhiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 507f09d  RATIS-1168. Move the utils from DataStreamBaseTest to a new file. (#292)
507f09d is described below

commit 507f09dfabbfd9eb25852d1bd24fbff9377909bc
Author: Tsz-Wo Nicholas Sze <sz...@apache.org>
AuthorDate: Sat Nov 21 10:15:59 2020 +0800

    RATIS-1168. Move the utils from DataStreamBaseTest to a new file. (#292)
---
 .../ratis/datastream/DataStreamBaseTest.java       | 251 ++++-----------------
 .../ratis/datastream/DataStreamClusterTests.java   |   6 +-
 .../ratis/datastream/DataStreamTestUtils.java      | 245 ++++++++++++++++++++
 .../ratis/datastream/TestDataStreamNetty.java      |   4 +-
 4 files changed, 295 insertions(+), 211 deletions(-)

diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
index 17f675a..113fcb6 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
@@ -23,18 +23,25 @@ import org.apache.ratis.client.RaftClient;
 import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
 import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.datastream.DataStreamTestUtils.DataChannel;
+import org.apache.ratis.datastream.DataStreamTestUtils.MultiDataStreamStateMachine;
+import org.apache.ratis.datastream.DataStreamTestUtils.SingleDataStream;
 import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
-import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
-import org.apache.ratis.proto.RaftProtos.*;
-import org.apache.ratis.protocol.ClientInvocationId;
+import org.apache.ratis.proto.RaftProtos.AppendEntriesReplyProto;
+import org.apache.ratis.proto.RaftProtos.AppendEntriesRequestProto;
+import org.apache.ratis.proto.RaftProtos.InstallSnapshotReplyProto;
+import org.apache.ratis.proto.RaftProtos.InstallSnapshotRequestProto;
+import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
+import org.apache.ratis.proto.RaftProtos.RequestVoteReplyProto;
+import org.apache.ratis.proto.RaftProtos.RequestVoteRequestProto;
 import org.apache.ratis.protocol.ClientId;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.DataStreamReply;
 import org.apache.ratis.protocol.GroupInfoReply;
 import org.apache.ratis.protocol.GroupInfoRequest;
 import org.apache.ratis.protocol.GroupListReply;
 import org.apache.ratis.protocol.GroupListRequest;
 import org.apache.ratis.protocol.GroupManagementRequest;
-import org.apache.ratis.protocol.RaftClientMessage;
 import org.apache.ratis.protocol.RaftClientReply;
 import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.protocol.RaftGroup;
@@ -42,27 +49,20 @@ import org.apache.ratis.protocol.RaftGroupId;
 import org.apache.ratis.protocol.RaftGroupMemberId;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.protocol.RaftPeerId;
-import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
 import org.apache.ratis.protocol.SetConfigurationRequest;
-import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
 import org.apache.ratis.rpc.RpcType;
 import org.apache.ratis.server.DataStreamMap;
 import org.apache.ratis.server.RaftServer;
 import org.apache.ratis.server.impl.DataStreamServerImpl;
 import org.apache.ratis.server.impl.RaftServerTestUtil;
 import org.apache.ratis.server.impl.ServerFactory;
+import org.apache.ratis.statemachine.StateMachine;
 import org.apache.ratis.statemachine.StateMachine.StateMachineDataChannel;
-import org.apache.ratis.statemachine.impl.BaseStateMachine;
-import org.apache.ratis.statemachine.StateMachine.DataStream;
-import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
-import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.LifeCycle;
 import org.apache.ratis.util.NetUtils;
 import org.junit.Assert;
-import org.slf4j.Logger;
 
 import java.io.IOException;
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -73,108 +73,9 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.Executor;
 import java.util.concurrent.Executors;
-import java.util.concurrent.ThreadLocalRandom;
 import java.util.stream.Collectors;
 
 abstract class DataStreamBaseTest extends BaseTest {
-  static final int MODULUS = 23;
-
-  static byte pos2byte(int pos) {
-    return (byte) ('A' + pos%MODULUS);
-  }
-
-  static final ByteString MOCK = ByteString.copyFromUtf8("mock");
-
-  static ByteString bytesWritten2ByteString(long bytesWritten) {
-    return ByteString.copyFromUtf8("bytesWritten=" + bytesWritten);
-  }
-
-  private final Executor executor = Executors.newFixedThreadPool(16);
-
-  static class MultiDataStreamStateMachine extends BaseStateMachine {
-    private final ConcurrentMap<ClientInvocationId, SingleDataStream> streams = new ConcurrentHashMap<>();
-
-    @Override
-    public CompletableFuture<DataStream> stream(RaftClientRequest request) {
-      final SingleDataStream s = new SingleDataStream(request);
-      streams.put(ClientInvocationId.valueOf(request), s);
-      return CompletableFuture.completedFuture(s);
-    }
-
-    SingleDataStream getSingleDataStream(RaftClientRequest request) {
-      return streams.get(ClientInvocationId.valueOf(request));
-    }
-  }
-
-  static class SingleDataStream implements DataStream {
-    private int byteWritten = 0;
-    private final RaftClientRequest writeRequest;
-    private int forcedPosition = 0;
-
-    final StateMachineDataChannel channel = new StateMachineDataChannel() {
-      @Override
-      public void force(boolean metadata) {
-        forcedPosition = byteWritten;
-      }
-
-      private volatile boolean open = true;
-
-      @Override
-      public int write(ByteBuffer src) {
-        if (!open) {
-          throw new IllegalStateException("Already closed");
-        }
-        final int remaining = src.remaining();
-        for(; src.remaining() > 0; ) {
-          Assert.assertEquals(pos2byte(byteWritten), src.get());
-          byteWritten += 1;
-        }
-        return remaining;
-      }
-
-      @Override
-      public boolean isOpen() {
-        return open;
-      }
-
-      @Override
-      public void close() {
-        open = false;
-      }
-    };
-
-      @Override
-      public StateMachineDataChannel getWritableByteChannel() {
-        return channel;
-      }
-
-      @Override
-      public CompletableFuture<?> cleanUp() {
-        try {
-          channel.close();
-        } catch (Throwable t) {
-          return JavaUtils.completeExceptionally(t);
-        }
-        return CompletableFuture.completedFuture(null);
-      }
-
-    SingleDataStream(RaftClientRequest request) {
-      this.writeRequest = request;
-    }
-
-    public int getByteWritten() {
-      return byteWritten;
-    }
-
-    public RaftClientRequest getWriteRequest() {
-      return writeRequest;
-    }
-
-    public int getForcedPosition() {
-      return forcedPosition;
-    }
-  }
-
   static class MyDivision implements RaftServer.Division {
     private final MultiDataStreamStateMachine stateMachine = new MultiDataStreamStateMachine();
     private final DataStreamMap streamMap;
@@ -240,18 +141,29 @@ abstract class DataStreamBaseTest extends BaseTest {
 
   private List<Server> servers;
   private RaftGroup raftGroup;
+  private final Executor executor = Executors.newFixedThreadPool(16);
 
   Server getPrimaryServer() {
     return servers.get(0);
   }
 
-  protected RaftServer newRaftServer(RaftPeer peer, RaftProperties properties) {
-    return new RaftServer() {
+  protected MyRaftServer newRaftServer(RaftPeer peer, RaftProperties properties) {
+    return new MyRaftServer(peer.getId(), properties);
+  }
+
+  static class MyRaftServer implements RaftServer {
+      private final RaftPeerId id;
+      private final RaftProperties properties;
       private final ConcurrentMap<RaftGroupId, MyDivision> divisions = new ConcurrentHashMap<>();
 
+      MyRaftServer(RaftPeerId id, RaftProperties properties) {
+        this.id = id;
+        this.properties = properties;
+      }
+
       @Override
       public RaftPeerId getId() {
-        return peer.getId();
+        return id;
       }
 
       @Override
@@ -301,14 +213,21 @@ abstract class DataStreamBaseTest extends BaseTest {
 
       @Override
       public CompletableFuture<RaftClientReply> submitClientRequestAsync(RaftClientRequest request) {
-        final MultiDataStreamStateMachine stateMachine = getDivision(request.getRaftGroupId()).getStateMachine();
-        final SingleDataStream stream = stateMachine.getSingleDataStream(request);
-        Assert.assertFalse(stream.getWritableByteChannel().isOpen());
-        return CompletableFuture.completedFuture(RaftClientReply.newBuilder()
+        final MyDivision d = getDivision(request.getRaftGroupId());
+        return d.getDataStreamMap()
+            .remove(ClientInvocationId.valueOf(request))
+            .thenApply(StateMachine.DataStream::getWritableByteChannel)
+            .thenApply(channel -> buildRaftClientReply(request, channel));
+      }
+
+      static RaftClientReply buildRaftClientReply(RaftClientRequest request, StateMachineDataChannel channel) {
+        Assert.assertTrue(channel instanceof DataChannel);
+        final DataChannel dataChannel = (DataChannel) channel;
+        return RaftClientReply.newBuilder()
             .setRequest(request)
             .setSuccess()
-            .setMessage(() -> bytesWritten2ByteString(stream.getByteWritten()))
-            .build());
+            .setMessage(() -> DataStreamTestUtils.bytesWritten2ByteString(dataChannel.getBytesWritten()))
+            .build();
       }
 
       @Override
@@ -373,7 +292,6 @@ abstract class DataStreamBaseTest extends BaseTest {
       public LifeCycle.State getLifeCycleState() {
         return null;
       }
-    };
   }
 
 
@@ -465,7 +383,6 @@ abstract class DataStreamBaseTest extends BaseTest {
     }
   }
 
-
   void runTestMockCluster(ClientId clientId, int bufferSize, int bufferNum,
       Exception expectedException, Exception headerException)
       throws IOException {
@@ -491,42 +408,9 @@ abstract class DataStreamBaseTest extends BaseTest {
     }
   }
 
-  static int writeAndAssertReplies(DataStreamOutputImpl out, int bufferSize, int bufferNum) {
-    final List<CompletableFuture<DataStreamReply>> futures = new ArrayList<>();
-    final List<Integer> sizes = new ArrayList<>();
-
-    //send data
-    final int halfBufferSize = bufferSize/2;
-    int dataSize = 0;
-    for(int i = 0; i < bufferNum; i++) {
-      final int size = halfBufferSize + ThreadLocalRandom.current().nextInt(halfBufferSize);
-      sizes.add(size);
-
-      final ByteBuffer bf = initBuffer(dataSize, size);
-      futures.add(out.writeAsync(bf, i == bufferNum - 1));
-      dataSize += size;
-    }
-
-    { // check header
-      final DataStreamReply reply = out.getHeaderFuture().join();
-      Assert.assertTrue(reply.isSuccess());
-      Assert.assertEquals(0, reply.getBytesWritten());
-      Assert.assertEquals(reply.getType(), Type.STREAM_HEADER);
-    }
-
-    // check writeAsync requests
-    for(int i = 0; i < futures.size(); i++) {
-      final DataStreamReply reply = futures.get(i).join();
-      Assert.assertTrue(reply.isSuccess());
-      Assert.assertEquals(sizes.get(i).longValue(), reply.getBytesWritten());
-      Assert.assertEquals(reply.getType(), i == futures.size() - 1 ? Type.STREAM_DATA_SYNC : Type.STREAM_DATA);
-    }
-    return dataSize;
-  }
-
   CompletableFuture<RaftClientReply> runTestDataStream(DataStreamOutputImpl out, int bufferSize, int bufferNum) {
     LOG.info("start Stream{}", out.getHeader().getCallId());
-    final int bytesWritten = writeAndAssertReplies(out, bufferSize, bufferNum);
+    final int bytesWritten = DataStreamTestUtils.writeAndAssertReplies(out, bufferSize, bufferNum);
     try {
       for (Server s : servers) {
         assertHeader(s, out.getHeader(), bytesWritten);
@@ -535,32 +419,7 @@ abstract class DataStreamBaseTest extends BaseTest {
       throw new CompletionException(e);
     }
 
-    return out.closeAsync().thenCompose(reply -> assertCloseReply(out, reply, bytesWritten));
-  }
-
-  static CompletableFuture<RaftClientReply> assertCloseReply(DataStreamOutputImpl out, DataStreamReply dataStreamReply,
-      long bytesWritten) {
-    // Test close idempotent
-    Assert.assertSame(dataStreamReply, out.closeAsync().join());
-    testFailureCase("writeAsync should fail",
-        () -> out.writeAsync(DataStreamRequestByteBuffer.EMPTY_BYTE_BUFFER).join(),
-        CompletionException.class, (Logger)null, AlreadyClosedException.class);
-
-    try {
-      final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(RaftClientReplyProto.parseFrom(
-          ((DataStreamReplyByteBuffer) dataStreamReply).slice()));
-      assertRaftClientMessage(out.getHeader(), reply);
-      if (reply.isSuccess()) {
-        final ByteString bytes = reply.getMessage().getContent();
-        if (!bytes.equals(MOCK)) {
-          Assert.assertEquals(bytesWritten2ByteString(bytesWritten), bytes);
-        }
-      }
-
-      return CompletableFuture.completedFuture(reply);
-    } catch (Throwable t) {
-      return JavaUtils.completeExceptionally(t);
-    }
+    return out.closeAsync().thenCompose(reply -> DataStreamTestUtils.assertCloseReply(out, reply, bytesWritten));
   }
 
   void assertHeader(Server server, RaftClientRequest header, int dataSize) throws Exception {
@@ -570,34 +429,14 @@ abstract class DataStreamBaseTest extends BaseTest {

     // check stream
     final MyDivision d = server.getDivision(header.getRaftGroupId());
-    Assert.assertNotNull(d.getDataStreamMap().remove(ClientInvocationId.valueOf(header)));
     final SingleDataStream stream = d.getStateMachine().getSingleDataStream(header);
-    Assert.assertEquals(dataSize, stream.getByteWritten());
-    Assert.assertEquals(dataSize, stream.getForcedPosition());
+    final DataChannel channel = stream.getWritableByteChannel(); // same stream as above; no second lookup
+    Assert.assertEquals(dataSize, channel.getBytesWritten());
+    Assert.assertEquals(dataSize, channel.getForcedPosition());

     // check writeRequest
     final RaftClientRequest writeRequest = stream.getWriteRequest();
     Assert.assertEquals(RaftClientRequest.dataStreamRequestType(), writeRequest.getType());
-    assertRaftClientMessage(header, writeRequest);
-  }
-
-  static void assertRaftClientMessage(RaftClientRequest expected, RaftClientMessage computed) {
-    Assert.assertNotNull(computed);
-    Assert.assertEquals(expected.getClientId(), computed.getClientId());
-    Assert.assertEquals(expected.getServerId(), computed.getServerId());
-    Assert.assertEquals(expected.getRaftGroupId(), computed.getRaftGroupId());
-    Assert.assertEquals(expected.getCallId(), computed.getCallId());
-  }
-
-  static ByteBuffer initBuffer(int offset, int size) {
-    final ByteBuffer buffer = ByteBuffer.allocateDirect(size);
-    final int length = buffer.capacity();
-    buffer.position(0).limit(length);
-    for (int j = 0; j < length; j++) {
-      buffer.put(pos2byte(offset + j));
-    }
-    buffer.flip();
-    Assert.assertEquals(length, buffer.remaining());
-    return buffer;
+    DataStreamTestUtils.assertRaftClientMessage(header, writeRequest);
   }
 }
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
index bf1210e..4020761 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
@@ -21,7 +21,7 @@ import org.apache.ratis.BaseTest;
 import org.apache.ratis.MiniRaftCluster;
 import org.apache.ratis.client.RaftClient;
 import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
-import org.apache.ratis.datastream.DataStreamBaseTest.MultiDataStreamStateMachine;
+import org.apache.ratis.datastream.DataStreamTestUtils.MultiDataStreamStateMachine;
 import org.apache.ratis.proto.RaftProtos.LogEntryProto;
 import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
 import org.apache.ratis.protocol.ClientId;
@@ -65,7 +65,7 @@ public abstract class DataStreamClusterTests<CLUSTER extends MiniRaftCluster> ex
 
       // send a stream request
       try(final DataStreamOutputImpl out = (DataStreamOutputImpl) client.getDataStreamApi().stream()) {
-        DataStreamBaseTest.writeAndAssertReplies(out, 1000, 10);
+        DataStreamTestUtils.writeAndAssertReplies(out, 1000, 10);
         callId = out.getHeader().getCallId();
       }
     }
@@ -83,7 +83,7 @@ public abstract class DataStreamClusterTests<CLUSTER extends MiniRaftCluster> ex
       if (entry.hasStateMachineLogEntry()) {
         final StateMachineLogEntryProto stateMachineEntry = entry.getStateMachineLogEntry();
         if (stateMachineEntry.getCallId() == callId) {
-          if (clientId.equals(ClientId.valueOf(stateMachineEntry.getClientId()))) {
+          if (clientId.toByteString().equals(stateMachineEntry.getClientId())) {
             return entry;
           }
         }
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
new file mode 100644
index 0000000..a834cd8
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.datastream;
+
+import org.apache.ratis.BaseTest;
+import org.apache.ratis.client.impl.ClientProtoUtils;
+import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
+import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
+import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
+import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
+import org.apache.ratis.proto.RaftProtos.LogEntryProto;
+import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
+import org.apache.ratis.protocol.ClientInvocationId;
+import org.apache.ratis.protocol.DataStreamReply;
+import org.apache.ratis.protocol.RaftClientMessage;
+import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.protocol.RaftClientRequest;
+import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
+import org.apache.ratis.statemachine.StateMachine.DataStream;
+import org.apache.ratis.statemachine.StateMachine.StateMachineDataChannel;
+import org.apache.ratis.statemachine.impl.BaseStateMachine;
+import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.util.JavaUtils;
+import org.junit.Assert;
+import org.slf4j.Logger;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ThreadLocalRandom;
+
+public interface DataStreamTestUtils {
+  ByteString MOCK = ByteString.copyFromUtf8("mock");
+  int MODULUS = 23;
+
+  static byte pos2byte(int pos) {
+    return (byte) ('A' + pos % MODULUS);
+  }
+
+  static ByteString bytesWritten2ByteString(long bytesWritten) {
+    return ByteString.copyFromUtf8("bytesWritten=" + bytesWritten);
+  }
+
+  class MultiDataStreamStateMachine extends BaseStateMachine {
+    private final ConcurrentMap<ClientInvocationId, SingleDataStream> streams = new ConcurrentHashMap<>();
+
+    @Override
+    public CompletableFuture<DataStream> stream(RaftClientRequest request) {
+      final SingleDataStream s = new SingleDataStream(request);
+      streams.put(ClientInvocationId.valueOf(request), s);
+      return CompletableFuture.completedFuture(s);
+    }
+
+    @Override
+    public CompletableFuture<?> link(DataStream stream, LogEntryProto entry) {
+      final SingleDataStream s = getSingleDataStream(ClientInvocationId.valueOf(entry.getStateMachineLogEntry()));
+      s.setLogEntry(entry);
+      return CompletableFuture.completedFuture(null);
+    }
+
+    SingleDataStream getSingleDataStream(RaftClientRequest request) {
+      return getSingleDataStream(ClientInvocationId.valueOf(request));
+    }
+
+    SingleDataStream getSingleDataStream(ClientInvocationId invocationId) {
+      return streams.get(invocationId);
+    }
+  }
+
+  class SingleDataStream implements DataStream {
+    private final RaftClientRequest writeRequest;
+    private final DataChannel channel = new DataChannel();
+    private volatile LogEntryProto logEntry;
+
+    SingleDataStream(RaftClientRequest request) {
+      this.writeRequest = request;
+    }
+
+    @Override
+    public DataChannel getWritableByteChannel() {
+      return channel;
+    }
+
+    @Override
+    public CompletableFuture<?> cleanUp() {
+      try {
+        channel.close();
+      } catch (Throwable t) {
+        return JavaUtils.completeExceptionally(t);
+      }
+      return CompletableFuture.completedFuture(null);
+    }
+
+    void setLogEntry(LogEntryProto logEntry) {
+      this.logEntry = logEntry;
+    }
+
+    LogEntryProto getLogEntry() {
+      return logEntry;
+    }
+
+    RaftClientRequest getWriteRequest() {
+      return writeRequest;
+    }
+  }
+
+  class DataChannel implements StateMachineDataChannel {
+    private volatile boolean open = true;
+    private int bytesWritten = 0;
+    private int forcedPosition = 0;
+
+    int getBytesWritten() {
+      return bytesWritten;
+    }
+
+    int getForcedPosition() {
+      return forcedPosition;
+    }
+
+    @Override
+    public void force(boolean metadata) {
+      forcedPosition = bytesWritten;
+    }
+
+    @Override
+    public int write(ByteBuffer src) {
+      if (!open) {
+        throw new IllegalStateException("Already closed");
+      }
+      final int remaining = src.remaining();
+      for (; src.remaining() > 0; ) {
+        Assert.assertEquals(pos2byte(bytesWritten), src.get());
+        bytesWritten += 1;
+      }
+      return remaining;
+    }
+
+    @Override
+    public boolean isOpen() {
+      return open;
+    }
+
+    @Override
+    public void close() {
+      open = false;
+    }
+  }
+
+  static int writeAndAssertReplies(DataStreamOutputImpl out, int bufferSize, int bufferNum) {
+    final List<CompletableFuture<DataStreamReply>> futures = new ArrayList<>();
+    final List<Integer> sizes = new ArrayList<>();
+
+    // send bufferNum buffers, each of random size in [bufferSize/2, bufferSize); the last is flagged
+    final int halfBufferSize = bufferSize / 2;
+    int dataSize = 0;
+    for (int i = 0; i < bufferNum; i++) {
+      final int size = halfBufferSize + ThreadLocalRandom.current().nextInt(halfBufferSize);
+      sizes.add(size);
+
+      final ByteBuffer bf = initBuffer(dataSize, size);
+      futures.add(out.writeAsync(bf, i == bufferNum - 1));
+      dataSize += size;
+    }
+
+    { // check header: acknowledged with zero bytes written
+      final DataStreamReply reply = out.getHeaderFuture().join();
+      Assert.assertTrue(reply.isSuccess());
+      Assert.assertEquals(0, reply.getBytesWritten());
+      Assert.assertEquals(Type.STREAM_HEADER, reply.getType());
+    }
+
+    // check writeAsync requests: each reply echoes its size; the flagged last write replies STREAM_DATA_SYNC
+    for (int i = 0; i < futures.size(); i++) {
+      final DataStreamReply reply = futures.get(i).join();
+      Assert.assertTrue(reply.isSuccess());
+      Assert.assertEquals(sizes.get(i).longValue(), reply.getBytesWritten());
+      Assert.assertEquals(i == futures.size() - 1 ? Type.STREAM_DATA_SYNC : Type.STREAM_DATA, reply.getType());
+    }
+    return dataSize;
+  }
+
+  static CompletableFuture<RaftClientReply> assertCloseReply(DataStreamOutputImpl out, DataStreamReply dataStreamReply,
+      long bytesWritten) {
+    // Test close idempotent
+    Assert.assertSame(dataStreamReply, out.closeAsync().join());
+    BaseTest.testFailureCase("writeAsync should fail",
+        () -> out.writeAsync(DataStreamRequestByteBuffer.EMPTY_BYTE_BUFFER).join(),
+        CompletionException.class, (Logger) null, AlreadyClosedException.class);
+
+    try {
+      final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(RaftClientReplyProto.parseFrom(
+          ((DataStreamReplyByteBuffer) dataStreamReply).slice()));
+      assertRaftClientMessage(out.getHeader(), reply);
+      if (reply.isSuccess()) {
+        final ByteString bytes = reply.getMessage().getContent();
+        if (!bytes.equals(MOCK)) {
+          Assert.assertEquals(bytesWritten2ByteString(bytesWritten), bytes);
+        }
+      }
+
+      return CompletableFuture.completedFuture(reply);
+    } catch (Throwable t) {
+      return JavaUtils.completeExceptionally(t);
+    }
+  }
+
+  static void assertRaftClientMessage(RaftClientMessage expected, RaftClientMessage computed) {
+    Assert.assertNotNull(computed);
+    Assert.assertEquals(expected.getClientId(), computed.getClientId());
+    Assert.assertEquals(expected.getServerId(), computed.getServerId());
+    Assert.assertEquals(expected.getRaftGroupId(), computed.getRaftGroupId());
+    Assert.assertEquals(expected.getCallId(), computed.getCallId());
+  }
+
+  static ByteBuffer initBuffer(int offset, int size) {
+    final ByteBuffer buffer = ByteBuffer.allocateDirect(size);
+    final int length = buffer.capacity();
+    buffer.position(0).limit(length);
+    for (int j = 0; j < length; j++) {
+      buffer.put(pos2byte(offset + j));
+    }
+    buffer.flip();
+    Assert.assertEquals(length, buffer.remaining());
+    return buffer;
+  }
+}
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
index 1655a0c..e9ce41e 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
@@ -68,7 +68,7 @@ public class TestDataStreamNetty extends DataStreamBaseTest {
   }
 
   @Override
-  protected RaftServer newRaftServer(RaftPeer peer, RaftProperties properties) {
+  protected MyRaftServer newRaftServer(RaftPeer peer, RaftProperties properties) {
     final RaftProperties p = new RaftProperties(properties);
     NettyConfigKeys.DataStream.setPort(p, NetUtils.createSocketAddr(peer.getDataStreamAddress()).getPort());
     return super.newRaftServer(peer, p);
@@ -115,7 +115,7 @@ public class TestDataStreamNetty extends DataStreamBaseTest {
               if (isLeader) {
                 final RaftClientReply.Builder b = RaftClientReply.newBuilder().setRequest(r);
                 reply = leaderException != null? b.setException(leaderException).build()
-                    : b.setSuccess().setMessage(() -> MOCK).build();
+                    : b.setSuccess().setMessage(() -> DataStreamTestUtils.MOCK).build();
               } else {
                 final RaftGroupMemberId memberId = RaftGroupMemberId.valueOf(peerId, groupId);
                 final NotLeaderException notLeaderException = new NotLeaderException(memberId, suggestedLeader, null);