You are viewing a plain text version of this content. The canonical link for this content is not included in this plain-text version.
Posted to commits@ratis.apache.org by ru...@apache.org on 2020/11/24 01:44:16 UTC
[incubator-ratis] branch master updated: RATIS-1171. Allow null for the stream parameter in StateMachine.DataApi.link (#294)
This is an automated email from the ASF dual-hosted git repository.
runzhiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git
The following commit(s) were added to refs/heads/master by this push:
new 2b83c99 RATIS-1171. Allow null for the stream parameter in StateMachine.DataApi.link (#294)
2b83c99 is described below
commit 2b83c9935baaac6c4dbd96ddd0cfc4d0c944df6a
Author: Tsz-Wo Nicholas Sze <sz...@apache.org>
AuthorDate: Tue Nov 24 09:44:10 2020 +0800
RATIS-1171. Allow null for the stream parameter in StateMachine.DataApi.link (#294)
* RATIS-1171. Allow null for the stream parameter in StateMachine.DataApi.link.
* decodeDataStreamReplyByteBuffer should copy the ByteBuf.
---
.../apache/ratis/client/impl/ClientProtoUtils.java | 7 +++
.../apache/ratis/netty/NettyDataStreamUtils.java | 2 +-
.../ratis/netty/server/DataStreamManagement.java | 9 ++--
.../raftlog/segmented/SegmentedRaftLogWorker.java | 7 ++-
.../apache/ratis/statemachine/StateMachine.java | 5 ++
.../datastream/DataStreamAsyncClusterTests.java | 58 +++++++++++++++++-----
.../ratis/datastream/DataStreamBaseTest.java | 7 ++-
.../ratis/datastream/DataStreamClusterTests.java | 8 +++
.../ratis/datastream/DataStreamTestUtils.java | 53 +++++++++++++++-----
9 files changed, 119 insertions(+), 37 deletions(-)
diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
index c0591af..ab79f4a 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
@@ -17,7 +17,9 @@
*/
package org.apache.ratis.client.impl;
+import java.nio.ByteBuffer;
import java.util.Optional;
+
import org.apache.ratis.protocol.*;
import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
import org.apache.ratis.protocol.exceptions.DataStreamException;
@@ -27,6 +29,7 @@ import org.apache.ratis.protocol.exceptions.NotReplicatedException;
import org.apache.ratis.protocol.exceptions.RaftException;
import org.apache.ratis.protocol.exceptions.StateMachineException;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.ratis.proto.RaftProtos.*;
import org.apache.ratis.util.ProtoUtils;
@@ -263,6 +266,10 @@ public interface ClientProtoUtils {
return b.build();
}
+ static RaftClientReply toRaftClientReply(ByteBuffer buffer) throws InvalidProtocolBufferException {
+ return toRaftClientReply(RaftClientReplyProto.parseFrom(buffer));
+ }
+
static RaftClientReply toRaftClientReply(RaftClientReplyProto replyProto) {
final RaftRpcReplyProto rp = replyProto.getRpcReply();
final RaftGroupMemberId serverMemberId = ProtoUtils.toRaftGroupMemberId(rp.getReplyId(), rp.getRaftGroupId());
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
index 33d77fa..a3f850d 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
@@ -101,7 +101,7 @@ public interface NettyDataStreamUtils {
.map(header -> checkHeader(header, buf))
.map(header -> DataStreamReplyByteBuffer.newBuilder()
.setDataStreamReplyHeader(header)
- .setBuffer(decodeData(buf, header, ByteBuf::nioBuffer))
+ .setBuffer(decodeData(buf, header, b -> b.copy().nioBuffer()))
.build())
.orElse(null);
}
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
index c2a9e20..a78bd1f 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
@@ -23,7 +23,6 @@ import org.apache.ratis.client.impl.ClientProtoUtils;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
-import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto;
import org.apache.ratis.protocol.ClientId;
import org.apache.ratis.protocol.ClientInvocationId;
@@ -401,13 +400,13 @@ public class DataStreamManagement {
static RaftClientReply getRaftClientReply(DataStreamReply dataStreamReply) {
if (dataStreamReply instanceof DataStreamReplyByteBuffer) {
try {
- return ClientProtoUtils.toRaftClientReply(
- RaftClientReplyProto.parseFrom(((DataStreamReplyByteBuffer) dataStreamReply).slice()));
+ return ClientProtoUtils.toRaftClientReply(((DataStreamReplyByteBuffer) dataStreamReply).slice());
} catch (InvalidProtocolBufferException e) {
- throw new IllegalStateException("Failed to decode RaftClientReply");
+ throw new IllegalStateException("Failed to parse " + JavaUtils.getClassSimpleName(dataStreamReply.getClass())
+ + ": reply is " + dataStreamReply, e);
}
} else {
- throw new IllegalStateException("Unexpected reply type");
+ throw new IllegalStateException("Unexpected " + dataStreamReply.getClass() + ": reply is " + dataStreamReply);
}
}
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
index 6069469..63dd439 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
@@ -37,6 +37,7 @@ import org.apache.ratis.server.raftlog.segmented.SegmentedRaftLogCache.Truncatio
import org.apache.ratis.server.raftlog.segmented.SegmentedRaftLog.Task;
import org.apache.ratis.proto.RaftProtos.LogEntryProto;
import org.apache.ratis.statemachine.StateMachine;
+import org.apache.ratis.statemachine.StateMachine.DataStream;
import org.apache.ratis.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -463,8 +464,10 @@ class SegmentedRaftLogWorker implements Runnable {
if (this.entry == entry) {
final StateMachineLogEntryProto proto = entry.hasStateMachineLogEntry()? entry.getStateMachineLogEntry(): null;
if (stateMachine != null && proto != null && proto.getType() == StateMachineLogEntryProto.Type.DATASTREAM) {
- this.stateMachineFuture = server.getDataStreamMap().remove(ClientInvocationId.valueOf(proto))
- .thenApply(stream -> stateMachine.data().link(stream, entry));
+ final ClientInvocationId invocationId = ClientInvocationId.valueOf(proto);
+ final CompletableFuture<DataStream> removed = server.getDataStreamMap().remove(invocationId);
+ this.stateMachineFuture = removed == null? stateMachine.data().link(null, entry)
+ : removed.thenApply(stream -> stateMachine.data().link(stream, entry));
} else {
this.stateMachineFuture = null;
}
diff --git a/ratis-server/src/main/java/org/apache/ratis/statemachine/StateMachine.java b/ratis-server/src/main/java/org/apache/ratis/statemachine/StateMachine.java
index 74d4df5..32c5380 100644
--- a/ratis-server/src/main/java/org/apache/ratis/statemachine/StateMachine.java
+++ b/ratis-server/src/main/java/org/apache/ratis/statemachine/StateMachine.java
@@ -93,7 +93,12 @@ public interface StateMachine extends Closeable {
/**
* Link asynchronously the given stream with the given log entry.
+ * The given stream can be null if it is unavailable due to errors.
+ * In such case, the state machine may either recover the data by itself
+ * or complete the returned future exceptionally.
*
+ * @param stream the stream, which can possibly be null, to be linked.
+ * @param entry the log entry to be linked.
* @return a future for the link task.
*/
default CompletableFuture<?> link(DataStream stream, LogEntryProto entry) {
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamAsyncClusterTests.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamAsyncClusterTests.java
index 28aa30b..3c98523 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamAsyncClusterTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamAsyncClusterTests.java
@@ -21,7 +21,14 @@ import org.apache.ratis.MiniRaftCluster;
import org.apache.ratis.RaftTestUtil;
import org.apache.ratis.client.RaftClient;
import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
+import org.apache.ratis.datastream.DataStreamTestUtils.MultiDataStreamStateMachine;
+import org.apache.ratis.datastream.DataStreamTestUtils.SingleDataStream;
+import org.apache.ratis.proto.RaftProtos.ReplicationLevel;
+import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.server.RaftServer;
+import org.apache.ratis.server.impl.RaftServerImpl;
+import org.apache.ratis.server.impl.RaftServerProxy;
import org.apache.ratis.util.CollectionUtils;
import org.junit.Assert;
import org.junit.Test;
@@ -50,32 +57,57 @@ public abstract class DataStreamAsyncClusterTests<CLUSTER extends MiniRaftCluste
void runTestDataStream(CLUSTER cluster) throws Exception {
RaftTestUtil.waitForLeader(cluster);
- final List<CompletableFuture<Void>> futures = new ArrayList<>();
- futures.add(CompletableFuture.runAsync(() -> runTestDataStream(cluster, 5, 10, 1_000_000, 10), executor));
- futures.add(CompletableFuture.runAsync(() -> runTestDataStream(cluster, 2, 20, 1_000, 10_000), executor));
- futures.forEach(CompletableFuture::join);
+
+ final List<CompletableFuture<Long>> futures = new ArrayList<>();
+ futures.add(CompletableFuture.supplyAsync(() -> runTestDataStream(cluster, 5, 10, 1_000_000, 10), executor));
+ futures.add(CompletableFuture.supplyAsync(() -> runTestDataStream(cluster, 2, 20, 1_000, 10_000), executor));
+ final long maxIndex = futures.stream()
+ .map(CompletableFuture::join)
+ .max(Long::compareTo)
+ .orElseThrow(IllegalStateException::new);
+
+ // wait for all servers to catch up
+ try (RaftClient client = cluster.createClient()) {
+ client.async().watch(maxIndex, ReplicationLevel.ALL).join();
+ }
+ // assert all streams are linked
+ for (RaftServerProxy proxy : cluster.getServers()) {
+ final RaftServerImpl impl = proxy.getImpl(cluster.getGroupId());
+ final MultiDataStreamStateMachine stateMachine = (MultiDataStreamStateMachine) impl.getStateMachine();
+ for (SingleDataStream s : stateMachine.getStreams()) {
+ Assert.assertNotNull(s.getLogEntry());
+ }
+ }
}
- void runTestDataStream(CLUSTER cluster, int numClients, int numStreams, int bufferSize, int bufferNum) {
- final List<CompletableFuture<Void>> futures = new ArrayList<>();
+ Long runTestDataStream(CLUSTER cluster, int numClients, int numStreams, int bufferSize, int bufferNum) {
+ final List<CompletableFuture<Long>> futures = new ArrayList<>();
for (int j = 0; j < numClients; j++) {
- futures.add(CompletableFuture.runAsync(() -> runTestDataStream(cluster, numStreams, bufferSize, bufferNum), executor));
+ futures.add(CompletableFuture.supplyAsync(() -> runTestDataStream(cluster, numStreams, bufferSize, bufferNum), executor));
}
Assert.assertEquals(numClients, futures.size());
- futures.forEach(CompletableFuture::join);
+ return futures.stream()
+ .map(CompletableFuture::join)
+ .max(Long::compareTo)
+ .orElseThrow(IllegalStateException::new);
}
- void runTestDataStream(CLUSTER cluster, int numStreams, int bufferSize, int bufferNum) {
+ long runTestDataStream(CLUSTER cluster, int numStreams, int bufferSize, int bufferNum) {
final Iterable<RaftServer> servers = CollectionUtils.as(cluster.getServers(), s -> s);
- final List<CompletableFuture<Void>> futures = new ArrayList<>();
+ final RaftPeerId leader = cluster.getLeader().getId();
+ final List<CompletableFuture<RaftClientReply>> futures = new ArrayList<>();
try(RaftClient client = cluster.createClient()) {
for (int i = 0; i < numStreams; i++) {
final DataStreamOutputImpl out = (DataStreamOutputImpl) client.getDataStreamApi().stream();
- futures.add(CompletableFuture.runAsync(() -> DataStreamTestUtils.writeAndCloseAndAssertReplies(
- servers, out, bufferSize, bufferNum), executor));
+ futures.add(CompletableFuture.supplyAsync(() -> DataStreamTestUtils.writeAndCloseAndAssertReplies(
+ servers, leader, out, bufferSize, bufferNum).join(), executor));
}
Assert.assertEquals(numStreams, futures.size());
- futures.forEach(CompletableFuture::join);
+ return futures.stream()
+ .map(CompletableFuture::join)
+ .map(RaftClientReply::getLogIndex)
+ .max(Long::compareTo)
+ .orElseThrow(IllegalStateException::new);
} catch (IOException e) {
throw new CompletionException(e);
}
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
index c08e1bf..0f8d314 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
@@ -30,7 +30,6 @@ import org.apache.ratis.proto.RaftProtos.AppendEntriesReplyProto;
import org.apache.ratis.proto.RaftProtos.AppendEntriesRequestProto;
import org.apache.ratis.proto.RaftProtos.InstallSnapshotReplyProto;
import org.apache.ratis.proto.RaftProtos.InstallSnapshotRequestProto;
-import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
import org.apache.ratis.proto.RaftProtos.RequestVoteReplyProto;
import org.apache.ratis.proto.RaftProtos.RequestVoteRequestProto;
import org.apache.ratis.protocol.ClientId;
@@ -352,14 +351,14 @@ abstract class DataStreamBaseTest extends BaseTest {
if (headerException != null) {
final DataStreamReply headerReply = out.getHeaderFuture().join();
Assert.assertFalse(headerReply.isSuccess());
- final RaftClientReply clientReply = ClientProtoUtils.toRaftClientReply(RaftClientReplyProto.parseFrom(
- ((DataStreamReplyByteBuffer)headerReply).slice()));
+ final RaftClientReply clientReply = ClientProtoUtils.toRaftClientReply(
+ ((DataStreamReplyByteBuffer)headerReply).slice());
Assert.assertTrue(clientReply.getException().getMessage().contains(headerException.getMessage()));
return;
}
final RaftClientReply clientReply = DataStreamTestUtils.writeAndCloseAndAssertReplies(
- CollectionUtils.as(servers, Server::getRaftServer), out, bufferSize, bufferNum).join();
+ CollectionUtils.as(servers, Server::getRaftServer), null, out, bufferSize, bufferNum).join();
if (expectedException != null) {
Assert.assertFalse(clientReply.isSuccess());
Assert.assertTrue(clientReply.getException().getMessage().contains(
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
index 4020761..28dbc87 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
@@ -21,19 +21,23 @@ import org.apache.ratis.BaseTest;
import org.apache.ratis.MiniRaftCluster;
import org.apache.ratis.client.RaftClient;
import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
+import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.datastream.DataStreamTestUtils.MultiDataStreamStateMachine;
import org.apache.ratis.proto.RaftProtos.LogEntryProto;
import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
import org.apache.ratis.protocol.ClientId;
import org.apache.ratis.protocol.RaftGroup;
import org.apache.ratis.protocol.RaftPeer;
+import org.apache.ratis.server.RaftServerConfigKeys;
import org.apache.ratis.server.impl.RaftServerImpl;
import org.apache.ratis.server.protocol.TermIndex;
import org.apache.ratis.server.raftlog.RaftLog;
+import org.apache.ratis.util.TimeDuration;
import org.junit.Assert;
import org.junit.Test;
import java.util.Collection;
+import java.util.concurrent.TimeUnit;
import static org.apache.ratis.RaftTestUtil.waitForLeader;
@@ -41,6 +45,10 @@ public abstract class DataStreamClusterTests<CLUSTER extends MiniRaftCluster> ex
implements MiniRaftCluster.Factory.Get<CLUSTER> {
{
setStateMachine(MultiDataStreamStateMachine.class);
+
+ // Avoid changing leader
+ RaftServerConfigKeys.Rpc.setTimeoutMin(getProperties(), TimeDuration.valueOf(2, TimeUnit.SECONDS));
+ RaftServerConfigKeys.Rpc.setTimeoutMax(getProperties(), TimeDuration.valueOf(3, TimeUnit.SECONDS));
}
public static final int NUM_SERVERS = 3;
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
index d8bd383..2f695a0 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
@@ -24,16 +24,19 @@ import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
import org.apache.ratis.proto.RaftProtos.LogEntryProto;
-import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
import org.apache.ratis.protocol.ClientInvocationId;
import org.apache.ratis.protocol.DataStreamReply;
+import org.apache.ratis.protocol.Message;
import org.apache.ratis.protocol.RaftClientMessage;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftClientRequest;
+import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
import org.apache.ratis.server.RaftServer;
+import org.apache.ratis.server.impl.ServerProtoUtils;
import org.apache.ratis.statemachine.StateMachine.DataStream;
import org.apache.ratis.statemachine.StateMachine.StateMachineDataChannel;
+import org.apache.ratis.statemachine.TransactionContext;
import org.apache.ratis.statemachine.impl.BaseStateMachine;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
import org.apache.ratis.util.JavaUtils;
@@ -43,7 +46,10 @@ import org.slf4j.LoggerFactory;
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
@@ -76,11 +82,23 @@ public interface DataStreamTestUtils {
@Override
public CompletableFuture<?> link(DataStream stream, LogEntryProto entry) {
- final SingleDataStream s = getSingleDataStream(ClientInvocationId.valueOf(entry.getStateMachineLogEntry()));
- s.setLogEntry(entry);
+ LOG.info("link {}", stream);
+ if (stream == null) {
+ return JavaUtils.completeExceptionally(new IllegalStateException("Null stream: entry=" + entry));
+ }
+ ((SingleDataStream)stream).setLogEntry(entry);
return CompletableFuture.completedFuture(null);
}
+ @Override
+ public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
+ final LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry());
+ updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
+ final SingleDataStream s = getSingleDataStream(ClientInvocationId.valueOf(entry.getStateMachineLogEntry()));
+ final ByteString bytesWritten = bytesWritten2ByteString(s.getWritableByteChannel().getBytesWritten());
+ return CompletableFuture.completedFuture(() -> bytesWritten);
+ }
+
SingleDataStream getSingleDataStream(RaftClientRequest request) {
return getSingleDataStream(ClientInvocationId.valueOf(request));
}
@@ -88,6 +106,10 @@ public interface DataStreamTestUtils {
SingleDataStream getSingleDataStream(ClientInvocationId invocationId) {
return streams.get(invocationId);
}
+
+ Collection<SingleDataStream> getStreams() {
+ return streams.values();
+ }
}
class SingleDataStream implements DataStream {
@@ -125,6 +147,12 @@ public interface DataStreamTestUtils {
RaftClientRequest getWriteRequest() {
return writeRequest;
}
+
+ @Override
+ public String toString() {
+ return JavaUtils.getClassSimpleName(getClass()) + ": writeRequest=" + writeRequest
+ + ", logEntry=" + ServerProtoUtils.toString(logEntry);
+ }
}
class DataChannel implements StateMachineDataChannel {
@@ -203,7 +231,7 @@ public interface DataStreamTestUtils {
}
static CompletableFuture<RaftClientReply> writeAndCloseAndAssertReplies(
- Iterable<RaftServer> servers, DataStreamOutputImpl out, int bufferSize, int bufferNum) {
+ Iterable<RaftServer> servers, RaftPeerId leader, DataStreamOutputImpl out, int bufferSize, int bufferNum) {
LOG.info("start Stream{}", out.getHeader().getCallId());
final int bytesWritten = writeAndAssertReplies(out, bufferSize, bufferNum);
try {
@@ -213,8 +241,9 @@ public interface DataStreamTestUtils {
} catch (Throwable e) {
throw new CompletionException(e);
}
+ LOG.info("Stream{}: bytesWritten={}", out.getHeader().getCallId(), bytesWritten);
- return out.closeAsync().thenCompose(reply -> assertCloseReply(out, reply, bytesWritten));
+ return out.closeAsync().thenCompose(reply -> assertCloseReply(out, reply, bytesWritten, leader));
}
static void assertHeader(RaftServer server, RaftClientRequest header, int dataSize) throws Exception {
@@ -231,21 +260,21 @@ public interface DataStreamTestUtils {
// check writeRequest
final RaftClientRequest writeRequest = stream.getWriteRequest();
Assert.assertEquals(RaftClientRequest.dataStreamRequestType(), writeRequest.getType());
- assertRaftClientMessage(header, writeRequest);
+ assertRaftClientMessage(header, null, writeRequest);
}
static CompletableFuture<RaftClientReply> assertCloseReply(DataStreamOutputImpl out, DataStreamReply dataStreamReply,
- long bytesWritten) {
+ long bytesWritten, RaftPeerId leader) {
// Test close idempotent
Assert.assertSame(dataStreamReply, out.closeAsync().join());
BaseTest.testFailureCase("writeAsync should fail",
() -> out.writeAsync(DataStreamRequestByteBuffer.EMPTY_BYTE_BUFFER).join(),
CompletionException.class, (Logger) null, AlreadyClosedException.class);
+ final DataStreamReplyByteBuffer buffer = (DataStreamReplyByteBuffer) dataStreamReply;
try {
- final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(RaftClientReplyProto.parseFrom(
- ((DataStreamReplyByteBuffer) dataStreamReply).slice()));
- assertRaftClientMessage(out.getHeader(), reply);
+ final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(buffer.slice());
+ assertRaftClientMessage(out.getHeader(), leader, reply);
if (reply.isSuccess()) {
final ByteString bytes = reply.getMessage().getContent();
if (!bytes.equals(MOCK)) {
@@ -259,10 +288,10 @@ public interface DataStreamTestUtils {
}
}
- static void assertRaftClientMessage(RaftClientMessage expected, RaftClientMessage computed) {
+ static void assertRaftClientMessage(RaftClientMessage expected, RaftPeerId expectedServerId, RaftClientMessage computed) {
Assert.assertNotNull(computed);
Assert.assertEquals(expected.getClientId(), computed.getClientId());
- Assert.assertEquals(expected.getServerId(), computed.getServerId());
+ Assert.assertEquals(Optional.ofNullable(expectedServerId).orElseGet(expected::getServerId), computed.getServerId());
Assert.assertEquals(expected.getRaftGroupId(), computed.getRaftGroupId());
Assert.assertEquals(expected.getCallId(), computed.getCallId());
}