You are viewing a plain-text version of this content. (The canonical link to the original message was a hyperlink that is not preserved in this plain-text copy.)
Posted to commits@ratis.apache.org by ru...@apache.org on 2020/11/24 01:44:16 UTC

[incubator-ratis] branch master updated: RATIS-1171. Allow null for the stream parameter in StateMachine.DataApi.link (#294)

This is an automated email from the ASF dual-hosted git repository.

runzhiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 2b83c99  RATIS-1171. Allow null for the stream parameter in StateMachine.DataApi.link (#294)
2b83c99 is described below

commit 2b83c9935baaac6c4dbd96ddd0cfc4d0c944df6a
Author: Tsz-Wo Nicholas Sze <sz...@apache.org>
AuthorDate: Tue Nov 24 09:44:10 2020 +0800

    RATIS-1171. Allow null for the stream parameter in StateMachine.DataApi.link (#294)
    
    * RATIS-1171. Allow null for the stream parameter in StateMachine.DataApi.link.
    
    * decodeDataStreamReplyByteBuffer should copy ByteBuf.
---
 .../apache/ratis/client/impl/ClientProtoUtils.java |  7 +++
 .../apache/ratis/netty/NettyDataStreamUtils.java   |  2 +-
 .../ratis/netty/server/DataStreamManagement.java   |  9 ++--
 .../raftlog/segmented/SegmentedRaftLogWorker.java  |  7 ++-
 .../apache/ratis/statemachine/StateMachine.java    |  5 ++
 .../datastream/DataStreamAsyncClusterTests.java    | 58 +++++++++++++++++-----
 .../ratis/datastream/DataStreamBaseTest.java       |  7 ++-
 .../ratis/datastream/DataStreamClusterTests.java   |  8 +++
 .../ratis/datastream/DataStreamTestUtils.java      | 53 +++++++++++++++-----
 9 files changed, 119 insertions(+), 37 deletions(-)

diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
index c0591af..ab79f4a 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
@@ -17,7 +17,9 @@
  */
 package org.apache.ratis.client.impl;
 
+import java.nio.ByteBuffer;
 import java.util.Optional;
+
 import org.apache.ratis.protocol.*;
 import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
 import org.apache.ratis.protocol.exceptions.DataStreamException;
@@ -27,6 +29,7 @@ import org.apache.ratis.protocol.exceptions.NotReplicatedException;
 import org.apache.ratis.protocol.exceptions.RaftException;
 import org.apache.ratis.protocol.exceptions.StateMachineException;
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.ratis.proto.RaftProtos.*;
 import org.apache.ratis.util.ProtoUtils;
 
@@ -263,6 +266,10 @@ public interface ClientProtoUtils {
     return b.build();
   }
 
+  static RaftClientReply toRaftClientReply(ByteBuffer buffer) throws InvalidProtocolBufferException {
+    return toRaftClientReply(RaftClientReplyProto.parseFrom(buffer));
+  }
+
   static RaftClientReply toRaftClientReply(RaftClientReplyProto replyProto) {
     final RaftRpcReplyProto rp = replyProto.getRpcReply();
     final RaftGroupMemberId serverMemberId = ProtoUtils.toRaftGroupMemberId(rp.getReplyId(), rp.getRaftGroupId());
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
index 33d77fa..a3f850d 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
@@ -101,7 +101,7 @@ public interface NettyDataStreamUtils {
         .map(header -> checkHeader(header, buf))
         .map(header -> DataStreamReplyByteBuffer.newBuilder()
             .setDataStreamReplyHeader(header)
-            .setBuffer(decodeData(buf, header, ByteBuf::nioBuffer))
+            .setBuffer(decodeData(buf, header, b -> b.copy().nioBuffer()))
             .build())
         .orElse(null);
   }
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
index c2a9e20..a78bd1f 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
@@ -23,7 +23,6 @@ import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
 import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
-import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
 import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto;
 import org.apache.ratis.protocol.ClientId;
 import org.apache.ratis.protocol.ClientInvocationId;
@@ -401,13 +400,13 @@ public class DataStreamManagement {
   static RaftClientReply getRaftClientReply(DataStreamReply dataStreamReply) {
     if (dataStreamReply instanceof DataStreamReplyByteBuffer) {
       try {
-        return ClientProtoUtils.toRaftClientReply(
-            RaftClientReplyProto.parseFrom(((DataStreamReplyByteBuffer) dataStreamReply).slice()));
+        return ClientProtoUtils.toRaftClientReply(((DataStreamReplyByteBuffer) dataStreamReply).slice());
       } catch (InvalidProtocolBufferException e) {
-        throw new IllegalStateException("Failed to decode RaftClientReply");
+        throw new IllegalStateException("Failed to parse " + JavaUtils.getClassSimpleName(dataStreamReply.getClass())
+            + ": reply is " + dataStreamReply, e);
       }
     } else {
-      throw new IllegalStateException("Unexpected reply type");
+      throw new IllegalStateException("Unexpected " + dataStreamReply.getClass() + ": reply is " + dataStreamReply);
     }
   }
 
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
index 6069469..63dd439 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
@@ -37,6 +37,7 @@ import org.apache.ratis.server.raftlog.segmented.SegmentedRaftLogCache.Truncatio
 import org.apache.ratis.server.raftlog.segmented.SegmentedRaftLog.Task;
 import org.apache.ratis.proto.RaftProtos.LogEntryProto;
 import org.apache.ratis.statemachine.StateMachine;
+import org.apache.ratis.statemachine.StateMachine.DataStream;
 import org.apache.ratis.util.*;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -463,8 +464,10 @@ class SegmentedRaftLogWorker implements Runnable {
       if (this.entry == entry) {
         final StateMachineLogEntryProto proto = entry.hasStateMachineLogEntry()? entry.getStateMachineLogEntry(): null;
         if (stateMachine != null && proto != null && proto.getType() == StateMachineLogEntryProto.Type.DATASTREAM) {
-          this.stateMachineFuture = server.getDataStreamMap().remove(ClientInvocationId.valueOf(proto))
-              .thenApply(stream -> stateMachine.data().link(stream, entry));
+          final ClientInvocationId invocationId = ClientInvocationId.valueOf(proto);
+          final CompletableFuture<DataStream> removed = server.getDataStreamMap().remove(invocationId);
+          this.stateMachineFuture = removed == null? stateMachine.data().link(null, entry)
+              : removed.thenApply(stream -> stateMachine.data().link(stream, entry));
         } else {
           this.stateMachineFuture = null;
         }
diff --git a/ratis-server/src/main/java/org/apache/ratis/statemachine/StateMachine.java b/ratis-server/src/main/java/org/apache/ratis/statemachine/StateMachine.java
index 74d4df5..32c5380 100644
--- a/ratis-server/src/main/java/org/apache/ratis/statemachine/StateMachine.java
+++ b/ratis-server/src/main/java/org/apache/ratis/statemachine/StateMachine.java
@@ -93,7 +93,12 @@ public interface StateMachine extends Closeable {
 
     /**
      * Link asynchronously the given stream with the given log entry.
+     * The given stream can be null if it is unavailable due to errors.
+     * In such case, the state machine may either recover the data by itself
+     * or complete the returned future exceptionally.
      *
+     * @param stream the stream, which can possibly be null, to be linked.
+     * @param entry the log entry to be linked.
      * @return a future for the link task.
      */
     default CompletableFuture<?> link(DataStream stream, LogEntryProto entry) {
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamAsyncClusterTests.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamAsyncClusterTests.java
index 28aa30b..3c98523 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamAsyncClusterTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamAsyncClusterTests.java
@@ -21,7 +21,14 @@ import org.apache.ratis.MiniRaftCluster;
 import org.apache.ratis.RaftTestUtil;
 import org.apache.ratis.client.RaftClient;
 import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
+import org.apache.ratis.datastream.DataStreamTestUtils.MultiDataStreamStateMachine;
+import org.apache.ratis.datastream.DataStreamTestUtils.SingleDataStream;
+import org.apache.ratis.proto.RaftProtos.ReplicationLevel;
+import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.protocol.RaftPeerId;
 import org.apache.ratis.server.RaftServer;
+import org.apache.ratis.server.impl.RaftServerImpl;
+import org.apache.ratis.server.impl.RaftServerProxy;
 import org.apache.ratis.util.CollectionUtils;
 import org.junit.Assert;
 import org.junit.Test;
@@ -50,32 +57,57 @@ public abstract class DataStreamAsyncClusterTests<CLUSTER extends MiniRaftCluste
 
   void runTestDataStream(CLUSTER cluster) throws Exception {
     RaftTestUtil.waitForLeader(cluster);
-    final List<CompletableFuture<Void>> futures = new ArrayList<>();
-    futures.add(CompletableFuture.runAsync(() -> runTestDataStream(cluster, 5, 10, 1_000_000, 10), executor));
-    futures.add(CompletableFuture.runAsync(() -> runTestDataStream(cluster, 2, 20, 1_000, 10_000), executor));
-    futures.forEach(CompletableFuture::join);
+
+    final List<CompletableFuture<Long>> futures = new ArrayList<>();
+    futures.add(CompletableFuture.supplyAsync(() -> runTestDataStream(cluster, 5, 10, 1_000_000, 10), executor));
+    futures.add(CompletableFuture.supplyAsync(() -> runTestDataStream(cluster, 2, 20, 1_000, 10_000), executor));
+    final long maxIndex = futures.stream()
+        .map(CompletableFuture::join)
+        .max(Long::compareTo)
+        .orElseThrow(IllegalStateException::new);
+
+    // wait for all servers to catch up
+    try (RaftClient client = cluster.createClient()) {
+      client.async().watch(maxIndex, ReplicationLevel.ALL).join();
+    }
+    // assert all streams are linked
+    for (RaftServerProxy proxy : cluster.getServers()) {
+      final RaftServerImpl impl = proxy.getImpl(cluster.getGroupId());
+      final MultiDataStreamStateMachine stateMachine = (MultiDataStreamStateMachine) impl.getStateMachine();
+      for (SingleDataStream s : stateMachine.getStreams()) {
+        Assert.assertNotNull(s.getLogEntry());
+      }
+    }
   }
 
-  void runTestDataStream(CLUSTER cluster, int numClients, int numStreams, int bufferSize, int bufferNum) {
-    final List<CompletableFuture<Void>> futures = new ArrayList<>();
+  Long runTestDataStream(CLUSTER cluster, int numClients, int numStreams, int bufferSize, int bufferNum) {
+    final List<CompletableFuture<Long>> futures = new ArrayList<>();
     for (int j = 0; j < numClients; j++) {
-      futures.add(CompletableFuture.runAsync(() -> runTestDataStream(cluster, numStreams, bufferSize, bufferNum), executor));
+      futures.add(CompletableFuture.supplyAsync(() -> runTestDataStream(cluster, numStreams, bufferSize, bufferNum), executor));
     }
     Assert.assertEquals(numClients, futures.size());
-    futures.forEach(CompletableFuture::join);
+    return futures.stream()
+        .map(CompletableFuture::join)
+        .max(Long::compareTo)
+        .orElseThrow(IllegalStateException::new);
   }
 
-  void runTestDataStream(CLUSTER cluster, int numStreams, int bufferSize, int bufferNum) {
+  long runTestDataStream(CLUSTER cluster, int numStreams, int bufferSize, int bufferNum) {
     final Iterable<RaftServer> servers = CollectionUtils.as(cluster.getServers(), s -> s);
-    final List<CompletableFuture<Void>> futures = new ArrayList<>();
+    final RaftPeerId leader = cluster.getLeader().getId();
+    final List<CompletableFuture<RaftClientReply>> futures = new ArrayList<>();
     try(RaftClient client = cluster.createClient()) {
       for (int i = 0; i < numStreams; i++) {
         final DataStreamOutputImpl out = (DataStreamOutputImpl) client.getDataStreamApi().stream();
-        futures.add(CompletableFuture.runAsync(() -> DataStreamTestUtils.writeAndCloseAndAssertReplies(
-            servers, out, bufferSize, bufferNum), executor));
+        futures.add(CompletableFuture.supplyAsync(() -> DataStreamTestUtils.writeAndCloseAndAssertReplies(
+            servers, leader, out, bufferSize, bufferNum).join(), executor));
       }
       Assert.assertEquals(numStreams, futures.size());
-      futures.forEach(CompletableFuture::join);
+      return futures.stream()
+          .map(CompletableFuture::join)
+          .map(RaftClientReply::getLogIndex)
+          .max(Long::compareTo)
+          .orElseThrow(IllegalStateException::new);
     } catch (IOException e) {
       throw new CompletionException(e);
     }
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
index c08e1bf..0f8d314 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
@@ -30,7 +30,6 @@ import org.apache.ratis.proto.RaftProtos.AppendEntriesReplyProto;
 import org.apache.ratis.proto.RaftProtos.AppendEntriesRequestProto;
 import org.apache.ratis.proto.RaftProtos.InstallSnapshotReplyProto;
 import org.apache.ratis.proto.RaftProtos.InstallSnapshotRequestProto;
-import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
 import org.apache.ratis.proto.RaftProtos.RequestVoteReplyProto;
 import org.apache.ratis.proto.RaftProtos.RequestVoteRequestProto;
 import org.apache.ratis.protocol.ClientId;
@@ -352,14 +351,14 @@ abstract class DataStreamBaseTest extends BaseTest {
       if (headerException != null) {
         final DataStreamReply headerReply = out.getHeaderFuture().join();
         Assert.assertFalse(headerReply.isSuccess());
-        final RaftClientReply clientReply = ClientProtoUtils.toRaftClientReply(RaftClientReplyProto.parseFrom(
-            ((DataStreamReplyByteBuffer)headerReply).slice()));
+        final RaftClientReply clientReply = ClientProtoUtils.toRaftClientReply(
+            ((DataStreamReplyByteBuffer)headerReply).slice());
         Assert.assertTrue(clientReply.getException().getMessage().contains(headerException.getMessage()));
         return;
       }
 
       final RaftClientReply clientReply = DataStreamTestUtils.writeAndCloseAndAssertReplies(
-          CollectionUtils.as(servers, Server::getRaftServer), out, bufferSize, bufferNum).join();
+          CollectionUtils.as(servers, Server::getRaftServer), null, out, bufferSize, bufferNum).join();
       if (expectedException != null) {
         Assert.assertFalse(clientReply.isSuccess());
         Assert.assertTrue(clientReply.getException().getMessage().contains(
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
index 4020761..28dbc87 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
@@ -21,19 +21,23 @@ import org.apache.ratis.BaseTest;
 import org.apache.ratis.MiniRaftCluster;
 import org.apache.ratis.client.RaftClient;
 import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
+import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.datastream.DataStreamTestUtils.MultiDataStreamStateMachine;
 import org.apache.ratis.proto.RaftProtos.LogEntryProto;
 import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
 import org.apache.ratis.protocol.ClientId;
 import org.apache.ratis.protocol.RaftGroup;
 import org.apache.ratis.protocol.RaftPeer;
+import org.apache.ratis.server.RaftServerConfigKeys;
 import org.apache.ratis.server.impl.RaftServerImpl;
 import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.server.raftlog.RaftLog;
+import org.apache.ratis.util.TimeDuration;
 import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.Collection;
+import java.util.concurrent.TimeUnit;
 
 import static org.apache.ratis.RaftTestUtil.waitForLeader;
 
@@ -41,6 +45,10 @@ public abstract class DataStreamClusterTests<CLUSTER extends MiniRaftCluster> ex
     implements MiniRaftCluster.Factory.Get<CLUSTER> {
   {
     setStateMachine(MultiDataStreamStateMachine.class);
+
+    // Avoid changing leader
+    RaftServerConfigKeys.Rpc.setTimeoutMin(getProperties(), TimeDuration.valueOf(2, TimeUnit.SECONDS));
+    RaftServerConfigKeys.Rpc.setTimeoutMax(getProperties(), TimeDuration.valueOf(3, TimeUnit.SECONDS));
   }
 
   public static final int NUM_SERVERS = 3;
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
index d8bd383..2f695a0 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
@@ -24,16 +24,19 @@ import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
 import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
 import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
 import org.apache.ratis.proto.RaftProtos.LogEntryProto;
-import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
 import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.DataStreamReply;
+import org.apache.ratis.protocol.Message;
 import org.apache.ratis.protocol.RaftClientMessage;
 import org.apache.ratis.protocol.RaftClientReply;
 import org.apache.ratis.protocol.RaftClientRequest;
+import org.apache.ratis.protocol.RaftPeerId;
 import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
 import org.apache.ratis.server.RaftServer;
+import org.apache.ratis.server.impl.ServerProtoUtils;
 import org.apache.ratis.statemachine.StateMachine.DataStream;
 import org.apache.ratis.statemachine.StateMachine.StateMachineDataChannel;
+import org.apache.ratis.statemachine.TransactionContext;
 import org.apache.ratis.statemachine.impl.BaseStateMachine;
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
 import org.apache.ratis.util.JavaUtils;
@@ -43,7 +46,10 @@ import org.slf4j.LoggerFactory;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.ConcurrentHashMap;
@@ -76,11 +82,23 @@ public interface DataStreamTestUtils {
 
     @Override
     public CompletableFuture<?> link(DataStream stream, LogEntryProto entry) {
-      final SingleDataStream s = getSingleDataStream(ClientInvocationId.valueOf(entry.getStateMachineLogEntry()));
-      s.setLogEntry(entry);
+      LOG.info("link {}", stream);
+      if (stream == null) {
+        return JavaUtils.completeExceptionally(new IllegalStateException("Null stream: entry=" + entry));
+      }
+      ((SingleDataStream)stream).setLogEntry(entry);
       return CompletableFuture.completedFuture(null);
     }
 
+    @Override
+    public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
+      final LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry());
+      updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
+      final SingleDataStream s = getSingleDataStream(ClientInvocationId.valueOf(entry.getStateMachineLogEntry()));
+      final ByteString bytesWritten = bytesWritten2ByteString(s.getWritableByteChannel().getBytesWritten());
+      return CompletableFuture.completedFuture(() -> bytesWritten);
+    }
+
     SingleDataStream getSingleDataStream(RaftClientRequest request) {
       return getSingleDataStream(ClientInvocationId.valueOf(request));
     }
@@ -88,6 +106,10 @@ public interface DataStreamTestUtils {
     SingleDataStream getSingleDataStream(ClientInvocationId invocationId) {
       return streams.get(invocationId);
     }
+
+    Collection<SingleDataStream> getStreams() {
+      return streams.values();
+    }
   }
 
   class SingleDataStream implements DataStream {
@@ -125,6 +147,12 @@ public interface DataStreamTestUtils {
     RaftClientRequest getWriteRequest() {
       return writeRequest;
     }
+
+    @Override
+    public String toString() {
+      return JavaUtils.getClassSimpleName(getClass()) + ": writeRequest=" + writeRequest
+          + ", logEntry=" + ServerProtoUtils.toString(logEntry);
+    }
   }
 
   class DataChannel implements StateMachineDataChannel {
@@ -203,7 +231,7 @@ public interface DataStreamTestUtils {
   }
 
   static CompletableFuture<RaftClientReply> writeAndCloseAndAssertReplies(
-      Iterable<RaftServer> servers, DataStreamOutputImpl out, int bufferSize, int bufferNum) {
+      Iterable<RaftServer> servers, RaftPeerId leader, DataStreamOutputImpl out, int bufferSize, int bufferNum) {
     LOG.info("start Stream{}", out.getHeader().getCallId());
     final int bytesWritten = writeAndAssertReplies(out, bufferSize, bufferNum);
     try {
@@ -213,8 +241,9 @@ public interface DataStreamTestUtils {
     } catch (Throwable e) {
       throw new CompletionException(e);
     }
+    LOG.info("Stream{}: bytesWritten={}", out.getHeader().getCallId(), bytesWritten);
 
-    return out.closeAsync().thenCompose(reply -> assertCloseReply(out, reply, bytesWritten));
+    return out.closeAsync().thenCompose(reply -> assertCloseReply(out, reply, bytesWritten, leader));
   }
 
   static void assertHeader(RaftServer server, RaftClientRequest header, int dataSize) throws Exception {
@@ -231,21 +260,21 @@ public interface DataStreamTestUtils {
     // check writeRequest
     final RaftClientRequest writeRequest = stream.getWriteRequest();
     Assert.assertEquals(RaftClientRequest.dataStreamRequestType(), writeRequest.getType());
-    assertRaftClientMessage(header, writeRequest);
+    assertRaftClientMessage(header, null, writeRequest);
   }
 
   static CompletableFuture<RaftClientReply> assertCloseReply(DataStreamOutputImpl out, DataStreamReply dataStreamReply,
-      long bytesWritten) {
+      long bytesWritten, RaftPeerId leader) {
     // Test close idempotent
     Assert.assertSame(dataStreamReply, out.closeAsync().join());
     BaseTest.testFailureCase("writeAsync should fail",
         () -> out.writeAsync(DataStreamRequestByteBuffer.EMPTY_BYTE_BUFFER).join(),
         CompletionException.class, (Logger) null, AlreadyClosedException.class);
 
+    final DataStreamReplyByteBuffer buffer = (DataStreamReplyByteBuffer) dataStreamReply;
     try {
-      final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(RaftClientReplyProto.parseFrom(
-          ((DataStreamReplyByteBuffer) dataStreamReply).slice()));
-      assertRaftClientMessage(out.getHeader(), reply);
+      final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(buffer.slice());
+      assertRaftClientMessage(out.getHeader(), leader, reply);
       if (reply.isSuccess()) {
         final ByteString bytes = reply.getMessage().getContent();
         if (!bytes.equals(MOCK)) {
@@ -259,10 +288,10 @@ public interface DataStreamTestUtils {
     }
   }
 
-  static void assertRaftClientMessage(RaftClientMessage expected, RaftClientMessage computed) {
+  static void assertRaftClientMessage(RaftClientMessage expected, RaftPeerId expectedServerId, RaftClientMessage computed) {
     Assert.assertNotNull(computed);
     Assert.assertEquals(expected.getClientId(), computed.getClientId());
-    Assert.assertEquals(expected.getServerId(), computed.getServerId());
+    Assert.assertEquals(Optional.ofNullable(expectedServerId).orElseGet(expected::getServerId), computed.getServerId());
     Assert.assertEquals(expected.getRaftGroupId(), computed.getRaftGroupId());
     Assert.assertEquals(expected.getCallId(), computed.getCallId());
   }