Posted to commits@flink.apache.org by ga...@apache.org on 2022/08/06 14:38:30 UTC
[flink] branch master updated: [FLINK-27524][datastream] Introduce cache API to DataStream
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new cf1a29d47a5 [FLINK-27524][datastream] Introduce cache API to DataStream
cf1a29d47a5 is described below
commit cf1a29d47a5bb4fb92e98a36934e525d74bae17b
Author: sxnan <su...@gmail.com>
AuthorDate: Thu Jun 30 23:14:30 2022 +0800
[FLINK-27524][datastream] Introduce cache API to DataStream
This closes #20147.
---
.../executors/AbstractSessionClusterExecutor.java | 44 +-
.../apache/flink/client/program/ClusterClient.java | 22 +
.../flink/client/program/MiniClusterClient.java | 13 +
.../client/program/rest/RestClusterClient.java | 60 +++
.../client/program/rest/RestClusterClientTest.java | 206 +++++++++
.../execution/CacheSupportedPipelineExecutor.java | 55 +++
...st_stream_execution_environment_completeness.py | 3 +-
.../TaskDeploymentDescriptorFactory.java | 6 +-
.../io/network/partition/DataSetMetaInfo.java | 8 +-
.../flink/runtime/minicluster/MiniCluster.java | 25 ++
.../rest/messages/dataset/ClusterDataSetEntry.java | 2 +-
....java => ClusterDatasetCorruptedException.java} | 19 +-
.../flink/runtime/scheduler/DefaultScheduler.java | 7 +-
.../TaskDeploymentDescriptorFactoryTest.java | 4 +-
.../jobmaster/JobIntermediateDatasetReuseTest.java | 61 ++-
.../streaming/api/datastream/CachedDataStream.java | 61 +++
.../api/datastream/SideOutputDataStream.java | 56 +++
.../api/datastream/SingleOutputStreamOperator.java | 22 +-
.../environment/StreamExecutionEnvironment.java | 151 ++++++-
.../flink/streaming/api/graph/StreamEdge.java | 20 +-
.../flink/streaming/api/graph/StreamGraph.java | 30 +-
.../streaming/api/graph/StreamGraphGenerator.java | 3 +
.../flink/streaming/api/graph/StreamNode.java | 13 +
.../api/graph/StreamingJobGraphGenerator.java | 35 +-
.../api/transformations/CacheTransformation.java | 89 ++++
.../translators/CacheTransformationTranslator.java | 205 +++++++++
.../api/graph/StreamGraphGeneratorTest.java | 290 ++++++++++---
.../api/graph/StreamingJobGraphGeneratorTest.java | 474 +++++++++++++--------
.../streaming/api/scala/CachedDataStream.scala | 29 ++
.../flink/streaming/api/scala/DataStream.scala | 8 +
.../api/scala/StreamExecutionEnvironment.scala | 10 +-
.../scala/StreamingScalaAPICompletenessTest.scala | 6 +
.../MiniClusterPipelineExecutorServiceLoader.java | 20 +-
.../flink/test/streaming/runtime/CacheITCase.java | 328 ++++++++++++++
34 files changed, 2043 insertions(+), 342 deletions(-)
diff --git a/flink-clients/src/main/java/org/apache/flink/client/deployment/executors/AbstractSessionClusterExecutor.java b/flink-clients/src/main/java/org/apache/flink/client/deployment/executors/AbstractSessionClusterExecutor.java
index 7224375a5ee..eeff536de51 100644
--- a/flink-clients/src/main/java/org/apache/flink/client/deployment/executors/AbstractSessionClusterExecutor.java
+++ b/flink-clients/src/main/java/org/apache/flink/client/deployment/executors/AbstractSessionClusterExecutor.java
@@ -27,13 +27,17 @@ import org.apache.flink.client.deployment.ClusterDescriptor;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.client.program.ClusterClientProvider;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.CacheSupportedPipelineExecutor;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.core.execution.PipelineExecutor;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.function.FunctionUtils;
import javax.annotation.Nonnull;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -50,7 +54,7 @@ import static org.apache.flink.util.Preconditions.checkState;
@Internal
public class AbstractSessionClusterExecutor<
ClusterID, ClientFactory extends ClusterClientFactory<ClusterID>>
- implements PipelineExecutor {
+ implements CacheSupportedPipelineExecutor {
private final ClientFactory clusterClientFactory;
@@ -95,4 +99,42 @@ public class AbstractSessionClusterExecutor<
.whenCompleteAsync((ignored1, ignored2) -> clusterClient.close());
}
}
+
+ @Override
+ public CompletableFuture<Set<AbstractID>> listCompletedClusterDatasetIds(
+ Configuration configuration, ClassLoader userCodeClassloader) throws Exception {
+
+ try (final ClusterDescriptor<ClusterID> clusterDescriptor =
+ clusterClientFactory.createClusterDescriptor(configuration)) {
+ final ClusterID clusterID = clusterClientFactory.getClusterId(configuration);
+ checkState(clusterID != null);
+
+ final ClusterClientProvider<ClusterID> clusterClientProvider =
+ clusterDescriptor.retrieve(clusterID);
+
+ final ClusterClient<ClusterID> clusterClient = clusterClientProvider.getClusterClient();
+ return clusterClient.listCompletedClusterDatasetIds();
+ }
+ }
+
+ @Override
+ public CompletableFuture<Void> invalidateClusterDataset(
+ AbstractID clusterDatasetId,
+ Configuration configuration,
+ ClassLoader userCodeClassloader)
+ throws Exception {
+ try (final ClusterDescriptor<ClusterID> clusterDescriptor =
+ clusterClientFactory.createClusterDescriptor(configuration)) {
+ final ClusterID clusterID = clusterClientFactory.getClusterId(configuration);
+ checkState(clusterID != null);
+
+ final ClusterClientProvider<ClusterID> clusterClientProvider =
+ clusterDescriptor.retrieve(clusterID);
+
+ final ClusterClient<ClusterID> clusterClient = clusterClientProvider.getClusterClient();
+ return clusterClient
+ .invalidateClusterDataset(new IntermediateDataSetID(clusterDatasetId))
+ .thenApply(acknowledge -> null);
+ }
+ }
}
diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java b/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java
index b6b1ca08950..0cbc27487d5 100644
--- a/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java
+++ b/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java
@@ -29,12 +29,15 @@ import org.apache.flink.runtime.jobmaster.JobResult;
import org.apache.flink.runtime.messages.Acknowledge;
import org.apache.flink.runtime.operators.coordination.CoordinationRequest;
import org.apache.flink.runtime.operators.coordination.CoordinationResponse;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.FlinkException;
import javax.annotation.Nullable;
import java.util.Collection;
+import java.util.Collections;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
/**
@@ -184,4 +187,23 @@ public interface ClusterClient<T> extends AutoCloseable {
*/
CompletableFuture<CoordinationResponse> sendCoordinationRequest(
JobID jobId, OperatorID operatorId, CoordinationRequest request);
+
+ /**
+ * Returns the ids of the completed cluster datasets.
+ *
+ * @return A set of ids of the fully cached intermediate datasets.
+ */
+ default CompletableFuture<Set<AbstractID>> listCompletedClusterDatasetIds() {
+ return CompletableFuture.completedFuture(Collections.emptySet());
+ }
+
+ /**
+ * Invalidates the cached intermediate dataset with the given id.
+ *
+ * @param clusterDatasetId id of the cluster dataset to be invalidated.
+ * @return Future which will be completed when the cached dataset is invalidated.
+ */
+ default CompletableFuture<Void> invalidateClusterDataset(AbstractID clusterDatasetId) {
+ return CompletableFuture.completedFuture(null);
+ }
}
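Both new methods are default no-ops, so existing ClusterClient implementations keep working unchanged; only clients that can actually reach cluster partitions (the MiniCluster and REST clients below) override them. A minimal caller-side sketch, assuming a client obtained from some previously acquired ClusterClientProvider (the variable names are illustrative):

    // Hedged sketch: enumerate the fully cached datasets and release them all.
    // `clusterClientProvider` is an assumed, previously obtained provider.
    ClusterClient<?> client = clusterClientProvider.getClusterClient();
    Set<AbstractID> completed = client.listCompletedClusterDatasetIds().get();
    for (AbstractID datasetId : completed) {
        client.invalidateClusterDataset(datasetId).get(); // blocks until released
    }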
diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/MiniClusterClient.java b/flink-clients/src/main/java/org/apache/flink/client/program/MiniClusterClient.java
index 9a49991b221..80b949daf6f 100644
--- a/flink-clients/src/main/java/org/apache/flink/client/program/MiniClusterClient.java
+++ b/flink-clients/src/main/java/org/apache/flink/client/program/MiniClusterClient.java
@@ -26,6 +26,7 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.execution.SavepointFormatType;
import org.apache.flink.runtime.client.JobStatusMessage;
import org.apache.flink.runtime.executiongraph.AccessExecutionGraph;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobmaster.JobResult;
@@ -33,6 +34,7 @@ import org.apache.flink.runtime.messages.Acknowledge;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.operators.coordination.CoordinationRequest;
import org.apache.flink.runtime.operators.coordination.CoordinationResponse;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.SerializedValue;
import org.apache.flink.util.concurrent.FutureUtils;
@@ -46,6 +48,7 @@ import javax.annotation.Nullable;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
@@ -182,6 +185,16 @@ public class MiniClusterClient implements ClusterClient<MiniClusterClient.MiniCl
}
}
+ @Override
+ public CompletableFuture<Set<AbstractID>> listCompletedClusterDatasetIds() {
+ return miniCluster.listCompletedClusterDatasetIds();
+ }
+
+ @Override
+ public CompletableFuture<Void> invalidateClusterDataset(AbstractID clusterDatasetId) {
+ return miniCluster.invalidateClusterDataset(new IntermediateDataSetID(clusterDatasetId));
+ }
+
/** The type of the Cluster ID for the local {@link MiniCluster}. */
public enum MiniClusterId {
INSTANCE
diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/rest/RestClusterClient.java b/flink-clients/src/main/java/org/apache/flink/client/program/rest/RestClusterClient.java
index 0893f223da5..6ec95eba137 100644
--- a/flink-clients/src/main/java/org/apache/flink/client/program/rest/RestClusterClient.java
+++ b/flink-clients/src/main/java/org/apache/flink/client/program/rest/RestClusterClient.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.clusterframework.ApplicationStatus;
import org.apache.flink.runtime.highavailability.ClientHighAvailabilityServices;
import org.apache.flink.runtime.highavailability.ClientHighAvailabilityServicesFactory;
import org.apache.flink.runtime.highavailability.DefaultClientHighAvailabilityServicesFactory;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobmaster.JobResult;
@@ -66,6 +67,12 @@ import org.apache.flink.runtime.rest.messages.ResponseBody;
import org.apache.flink.runtime.rest.messages.TerminationModeQueryParameter;
import org.apache.flink.runtime.rest.messages.TriggerId;
import org.apache.flink.runtime.rest.messages.cluster.ShutdownHeaders;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetDeleteStatusHeaders;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetDeleteStatusMessageParameters;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetDeleteTriggerHeaders;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetDeleteTriggerMessageParameters;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetEntry;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetListHeaders;
import org.apache.flink.runtime.rest.messages.job.JobDetailsHeaders;
import org.apache.flink.runtime.rest.messages.job.JobDetailsInfo;
import org.apache.flink.runtime.rest.messages.job.JobExecutionResultHeaders;
@@ -93,10 +100,12 @@ import org.apache.flink.runtime.rest.messages.queue.QueueStatus;
import org.apache.flink.runtime.rest.util.RestClientException;
import org.apache.flink.runtime.rest.util.RestConstants;
import org.apache.flink.runtime.webmonitor.retriever.LeaderRetriever;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.ExecutorUtils;
import org.apache.flink.util.FlinkException;
import org.apache.flink.util.SerializedValue;
+import org.apache.flink.util.StringUtils;
import org.apache.flink.util.concurrent.ExecutorThreadFactory;
import org.apache.flink.util.concurrent.FixedRetryStrategy;
import org.apache.flink.util.concurrent.FutureUtils;
@@ -124,6 +133,7 @@ import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
@@ -667,6 +677,56 @@ public class RestClusterClient<T> implements ClusterClient<T> {
});
}
+ @Override
+ public CompletableFuture<Set<AbstractID>> listCompletedClusterDatasetIds() {
+ return sendRequest(ClusterDataSetListHeaders.INSTANCE)
+ .thenApply(
+ clusterDataSetListResponseBody ->
+ clusterDataSetListResponseBody.getDataSets().stream()
+ .filter(ClusterDataSetEntry::isComplete)
+ .map(ClusterDataSetEntry::getDataSetId)
+ .map(id -> new AbstractID(StringUtils.hexStringToByte(id)))
+ .collect(Collectors.toSet()));
+ }
+
+ @Override
+ public CompletableFuture<Void> invalidateClusterDataset(AbstractID clusterDatasetId) {
+
+ final ClusterDataSetDeleteTriggerHeaders triggerHeader =
+ ClusterDataSetDeleteTriggerHeaders.INSTANCE;
+ final ClusterDataSetDeleteTriggerMessageParameters parameters =
+ triggerHeader.getUnresolvedMessageParameters();
+ parameters.clusterDataSetIdPathParameter.resolve(
+ new IntermediateDataSetID(clusterDatasetId));
+
+ final CompletableFuture<TriggerResponse> triggerFuture =
+ sendRequest(triggerHeader, parameters);
+
+ final CompletableFuture<AsynchronousOperationInfo> clusterDatasetDeleteFuture =
+ triggerFuture.thenCompose(
+ triggerResponse -> {
+ final TriggerId triggerId = triggerResponse.getTriggerId();
+ final ClusterDataSetDeleteStatusHeaders statusHeaders =
+ ClusterDataSetDeleteStatusHeaders.INSTANCE;
+ final ClusterDataSetDeleteStatusMessageParameters
+ statusMessageParameters =
+ statusHeaders.getUnresolvedMessageParameters();
+ statusMessageParameters.triggerIdPathParameter.resolve(triggerId);
+
+ return pollResourceAsync(
+ () -> sendRequest(statusHeaders, statusMessageParameters));
+ });
+
+ return clusterDatasetDeleteFuture.thenApply(
+ asynchronousOperationInfo -> {
+ if (asynchronousOperationInfo.getFailureCause() == null) {
+ return null;
+ } else {
+ throw new CompletionException(asynchronousOperationInfo.getFailureCause());
+ }
+ });
+ }
+
@Override
public void shutDownCluster() {
try {
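The invalidation path above follows the cluster's generic asynchronous-operation protocol: one request triggers the delete and returns a TriggerId, and the client then polls the status endpoint (via pollResourceAsync) until the operation reaches a terminal state. A self-contained sketch of that trigger-then-poll shape, using plain CompletableFutures rather than Flink's REST types:

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.atomic.AtomicInteger;

    // Illustrative model of the trigger-then-poll protocol; not Flink code.
    public class TriggerPollSketch {
        private static final AtomicInteger polls = new AtomicInteger();

        static CompletableFuture<String> trigger() {
            return CompletableFuture.completedFuture("trigger-42"); // TriggerId stand-in
        }

        static CompletableFuture<Boolean> pollStatus(String triggerId) {
            // Pretend the asynchronous delete finishes on the third poll.
            return CompletableFuture.completedFuture(polls.incrementAndGet() >= 3);
        }

        static CompletableFuture<Void> pollUntilComplete(String triggerId) {
            return pollStatus(triggerId)
                    .thenCompose(
                            done ->
                                    done
                                            ? CompletableFuture.completedFuture((Void) null)
                                            : pollUntilComplete(triggerId));
        }

        public static void main(String[] args) throws Exception {
            trigger().thenCompose(TriggerPollSketch::pollUntilComplete).get();
            System.out.println("delete acknowledged after " + polls.get() + " polls");
        }
    }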
diff --git a/flink-clients/src/test/java/org/apache/flink/client/program/rest/RestClusterClientTest.java b/flink-clients/src/test/java/org/apache/flink/client/program/rest/RestClusterClientTest.java
index b730135c205..9daa4e48f82 100644
--- a/flink-clients/src/test/java/org/apache/flink/client/program/rest/RestClusterClientTest.java
+++ b/flink-clients/src/test/java/org/apache/flink/client/program/rest/RestClusterClientTest.java
@@ -35,6 +35,8 @@ import org.apache.flink.configuration.RestOptions;
import org.apache.flink.runtime.client.JobStatusMessage;
import org.apache.flink.runtime.clusterframework.ApplicationStatus;
import org.apache.flink.runtime.dispatcher.DispatcherGateway;
+import org.apache.flink.runtime.io.network.partition.DataSetMetaInfo;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -72,6 +74,13 @@ import org.apache.flink.runtime.rest.messages.ResponseBody;
import org.apache.flink.runtime.rest.messages.RuntimeMessageHeaders;
import org.apache.flink.runtime.rest.messages.TriggerId;
import org.apache.flink.runtime.rest.messages.TriggerIdPathParameter;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetDeleteStatusHeaders;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetDeleteStatusMessageParameters;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetDeleteTriggerHeaders;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetDeleteTriggerMessageParameters;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetIdPathParameter;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetListHeaders;
+import org.apache.flink.runtime.rest.messages.dataset.ClusterDataSetListResponseBody;
import org.apache.flink.runtime.rest.messages.job.JobExecutionResultHeaders;
import org.apache.flink.runtime.rest.messages.job.JobExecutionResultResponseBody;
import org.apache.flink.runtime.rest.messages.job.JobStatusInfoHeaders;
@@ -91,6 +100,7 @@ import org.apache.flink.runtime.rest.util.TestRestServerEndpoint;
import org.apache.flink.runtime.rpc.RpcUtils;
import org.apache.flink.runtime.webmonitor.TestingDispatcherGateway;
import org.apache.flink.runtime.webmonitor.retriever.GatewayRetriever;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.ConfigurationException;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FlinkException;
@@ -123,12 +133,14 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
@@ -437,6 +449,200 @@ class RestClusterClientTest {
}
}
+ @Test
+ public void testListCompletedClusterDatasetIds() throws Exception {
+ Set<AbstractID> expectedCompletedClusterDatasetIds = new HashSet<>();
+ expectedCompletedClusterDatasetIds.add(new AbstractID());
+ expectedCompletedClusterDatasetIds.add(new AbstractID());
+
+ try (TestRestServerEndpoint restServerEndpoint =
+ createRestServerEndpoint(
+ new TestListCompletedClusterDatasetHandler(
+ expectedCompletedClusterDatasetIds))) {
+ try (RestClusterClient<?> restClusterClient =
+ createRestClusterClient(restServerEndpoint.getServerAddress().getPort())) {
+ final Set<AbstractID> returnedIds =
+ restClusterClient.listCompletedClusterDatasetIds().get();
+ assertThat(returnedIds).isEqualTo(expectedCompletedClusterDatasetIds);
+ }
+ }
+ }
+
+ private class TestListCompletedClusterDatasetHandler
+ extends TestHandler<
+ EmptyRequestBody, ClusterDataSetListResponseBody, EmptyMessageParameters> {
+
+ private final Set<AbstractID> intermediateDataSetIds;
+
+ private TestListCompletedClusterDatasetHandler(Set<AbstractID> intermediateDataSetIds) {
+ super(ClusterDataSetListHeaders.INSTANCE);
+ this.intermediateDataSetIds = intermediateDataSetIds;
+ }
+
+ @Override
+ protected CompletableFuture<ClusterDataSetListResponseBody> handleRequest(
+ @Nonnull HandlerRequest<EmptyRequestBody> request,
+ @Nonnull DispatcherGateway gateway)
+ throws RestHandlerException {
+
+ Map<IntermediateDataSetID, DataSetMetaInfo> datasets = new HashMap<>();
+ intermediateDataSetIds.forEach(
+ id ->
+ datasets.put(
+ new IntermediateDataSetID(id),
+ DataSetMetaInfo.withNumRegisteredPartitions(1, 1)));
+ return CompletableFuture.completedFuture(ClusterDataSetListResponseBody.from(datasets));
+ }
+ }
+
+ @Test
+ public void testInvalidateClusterDataset() throws Exception {
+ final IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
+ final String exceptionMessage = "Test exception.";
+ final FlinkException testException = new FlinkException(exceptionMessage);
+
+ final TestClusterDatasetDeleteHandlers testClusterDatasetDeleteHandlers =
+ new TestClusterDatasetDeleteHandlers(intermediateDataSetID);
+ final TestClusterDatasetDeleteHandlers.TestClusterDatasetDeleteTriggerHandler
+ testClusterDatasetDeleteTriggerHandler =
+ testClusterDatasetDeleteHandlers
+ .new TestClusterDatasetDeleteTriggerHandler();
+ final TestClusterDatasetDeleteHandlers.TestClusterDatasetDeleteStatusHandler
+ testClusterDatasetDeleteStatusHandler =
+ testClusterDatasetDeleteHandlers
+ .new TestClusterDatasetDeleteStatusHandler(
+ OptionalFailure.of(AsynchronousOperationInfo.complete()),
+ OptionalFailure.of(
+ AsynchronousOperationInfo.completeExceptional(
+ new SerializedThrowable(testException))),
+ OptionalFailure.ofFailure(testException));
+
+ try (TestRestServerEndpoint restServerEndpoint =
+ createRestServerEndpoint(
+ testClusterDatasetDeleteStatusHandler,
+ testClusterDatasetDeleteTriggerHandler)) {
+ RestClusterClient<?> restClusterClient =
+ createRestClusterClient(restServerEndpoint.getServerAddress().getPort());
+
+ try {
+ {
+ final CompletableFuture<Void> invalidateCacheFuture =
+ restClusterClient.invalidateClusterDataset(intermediateDataSetID);
+ assertThat(invalidateCacheFuture.get()).isNull();
+ }
+
+ {
+ final CompletableFuture<Void> invalidateCacheFuture =
+ restClusterClient.invalidateClusterDataset(intermediateDataSetID);
+
+ try {
+ invalidateCacheFuture.get();
+ fail("Expected an exception");
+ } catch (ExecutionException ee) {
+ assertThat(
+ ExceptionUtils.findThrowableWithMessage(
+ ee, exceptionMessage)
+ .isPresent())
+ .isTrue();
+ }
+ }
+
+ {
+ try {
+ restClusterClient.invalidateClusterDataset(intermediateDataSetID).get();
+ fail("Expected an exception.");
+ } catch (ExecutionException ee) {
+ assertThat(
+ ExceptionUtils.findThrowable(ee, RestClientException.class)
+ .isPresent())
+ .isTrue();
+ }
+ }
+ } finally {
+ restClusterClient.close();
+ }
+ }
+ }
+
+ private class TestClusterDatasetDeleteHandlers {
+
+ private final TriggerId triggerId = new TriggerId();
+
+ private final IntermediateDataSetID intermediateDataSetID;
+
+ private TestClusterDatasetDeleteHandlers(IntermediateDataSetID intermediateDatasetId) {
+ this.intermediateDataSetID = Preconditions.checkNotNull(intermediateDatasetId);
+ }
+
+ private class TestClusterDatasetDeleteTriggerHandler
+ extends TestHandler<
+ EmptyRequestBody,
+ TriggerResponse,
+ ClusterDataSetDeleteTriggerMessageParameters> {
+ private TestClusterDatasetDeleteTriggerHandler() {
+ super(ClusterDataSetDeleteTriggerHeaders.INSTANCE);
+ }
+
+ @Override
+ protected CompletableFuture<TriggerResponse> handleRequest(
+ HandlerRequest<EmptyRequestBody> request, DispatcherGateway gateway)
+ throws RestHandlerException {
+ assertThat(request.getPathParameter(ClusterDataSetIdPathParameter.class))
+ .isEqualTo(intermediateDataSetID);
+ return CompletableFuture.completedFuture(new TriggerResponse(triggerId));
+ }
+ }
+
+ private class TestClusterDatasetDeleteStatusHandler
+ extends TestHandler<
+ EmptyRequestBody,
+ AsynchronousOperationResult<AsynchronousOperationInfo>,
+ ClusterDataSetDeleteStatusMessageParameters> {
+
+ private final Queue<OptionalFailure<AsynchronousOperationInfo>> responses;
+
+ private TestClusterDatasetDeleteStatusHandler(
+ OptionalFailure<AsynchronousOperationInfo>... responses) {
+ super(ClusterDataSetDeleteStatusHeaders.INSTANCE);
+ this.responses = new ArrayDeque<>(Arrays.asList(responses));
+ }
+
+ @Override
+ protected CompletableFuture<AsynchronousOperationResult<AsynchronousOperationInfo>>
+ handleRequest(
+ @Nonnull HandlerRequest<EmptyRequestBody> request,
+ @Nonnull DispatcherGateway gateway)
+ throws RestHandlerException {
+ final TriggerId actualTriggerId =
+ request.getPathParameter(TriggerIdPathParameter.class);
+
+ if (actualTriggerId.equals(triggerId)) {
+ final OptionalFailure<AsynchronousOperationInfo> nextResponse =
+ responses.poll();
+
+ if (nextResponse != null) {
+ if (nextResponse.isFailure()) {
+ throw new RestHandlerException(
+ "Failure",
+ HttpResponseStatus.BAD_REQUEST,
+ nextResponse.getFailureCause());
+ } else {
+ return CompletableFuture.completedFuture(
+ AsynchronousOperationResult.completed(
+ nextResponse.getUnchecked()));
+ }
+ } else {
+ throw new AssertionError();
+ }
+ } else {
+ throw new AssertionError();
+ }
+ }
+ }
+ }
+
@Test
void testListJobs() throws Exception {
try (TestRestServerEndpoint restServerEndpoint =
diff --git a/flink-core/src/main/java/org/apache/flink/core/execution/CacheSupportedPipelineExecutor.java b/flink-core/src/main/java/org/apache/flink/core/execution/CacheSupportedPipelineExecutor.java
new file mode 100644
index 00000000000..0ccbe626a39
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/core/execution/CacheSupportedPipelineExecutor.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.core.execution;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.util.AbstractID;
+
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+
+/** A pipeline executor that supports caching intermediate datasets. */
+@Internal
+public interface CacheSupportedPipelineExecutor extends PipelineExecutor {
+
+ /**
+ * Returns the ids of the completed cluster datasets.
+ *
+ * @param configuration the {@link Configuration} with the required parameters
+ * @param userCodeClassloader the {@link ClassLoader} to deserialize user code
+ * @return A set of ids of the fully cached intermediate datasets.
+ */
+ CompletableFuture<Set<AbstractID>> listCompletedClusterDatasetIds(
+ final Configuration configuration, final ClassLoader userCodeClassloader)
+ throws Exception;
+
+ /**
+ * Invalidates the cluster dataset with the given id.
+ *
+ * @param clusterDatasetId id of the cluster dataset to be invalidated.
+ * @param configuration the {@link Configuration} with the required parameters
+ * @param userCodeClassloader the {@link ClassLoader} to deserialize user code
+ * @return Future which will be completed when the cached dataset is invalidated.
+ */
+ CompletableFuture<Void> invalidateClusterDataset(
+ AbstractID clusterDatasetId,
+ final Configuration configuration,
+ final ClassLoader userCodeClassloader)
+ throws Exception;
+}
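Session cluster executors implement this interface (see AbstractSessionClusterExecutor above), while per-job deployments do not, so callers have to feature-check before using the cache operations. A hedged sketch of that check; `executor`, `configuration`, `classLoader`, and `datasetId` are assumed to be in scope:

    // Only cache-capable executors can list or invalidate cluster datasets.
    if (executor instanceof CacheSupportedPipelineExecutor) {
        final CacheSupportedPipelineExecutor cacheExecutor =
                (CacheSupportedPipelineExecutor) executor;
        cacheExecutor.invalidateClusterDataset(datasetId, configuration, classLoader).get();
    }
    // Otherwise the deployment cannot reuse cluster datasets across jobs.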
diff --git a/flink-python/pyflink/datastream/tests/test_stream_execution_environment_completeness.py b/flink-python/pyflink/datastream/tests/test_stream_execution_environment_completeness.py
index 7fcefe5d21e..df3453dd730 100644
--- a/flink-python/pyflink/datastream/tests/test_stream_execution_environment_completeness.py
+++ b/flink-python/pyflink/datastream/tests/test_stream_execution_environment_completeness.py
@@ -50,7 +50,8 @@ class StreamExecutionEnvironmentCompletenessTests(PythonAPICompletenessTestCase,
'setNumberOfExecutionRetries', 'executeAsync', 'registerJobListener',
'clearJobListeners', 'getJobListeners', 'fromSequence', 'getConfiguration',
'generateStreamGraph', 'getTransformations', 'areExplicitEnvironmentsAllowed',
- 'registerCollectIterator'}
+ 'registerCollectIterator', 'listCompletedClusterDatasets',
+ 'invalidateClusterDataset', 'registerCacheTransformation', 'close'}
if __name__ == '__main__':
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
index 21f5ef6f069..a2c45ed0c47 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
@@ -39,7 +39,7 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobType;
-import org.apache.flink.runtime.scheduler.CachedIntermediateDataSetCorruptedException;
+import org.apache.flink.runtime.scheduler.ClusterDatasetCorruptedException;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor;
@@ -254,7 +254,7 @@ public class TaskDeploymentDescriptorFactory {
}
public static TaskDeploymentDescriptorFactory fromExecution(Execution execution)
- throws IOException, CachedIntermediateDataSetCorruptedException {
+ throws IOException, ClusterDatasetCorruptedException {
final ExecutionVertex executionVertex = execution.getVertex();
final InternalExecutionGraphAccessor internalExecutionGraphAccessor =
executionVertex.getExecutionGraphAccessor();
@@ -263,7 +263,7 @@ public class TaskDeploymentDescriptorFactory {
clusterPartitionShuffleDescriptors =
getClusterPartitionShuffleDescriptors(executionVertex);
} catch (Throwable e) {
- throw new CachedIntermediateDataSetCorruptedException(
+ throw new ClusterDatasetCorruptedException(
e,
executionVertex
.getJobVertex()
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataSetMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataSetMetaInfo.java
index ae446ad5bf4..cd7bb7a0c0b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataSetMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataSetMetaInfo.java
@@ -21,14 +21,16 @@ import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.util.Preconditions;
+import java.io.Serializable;
import java.util.Comparator;
import java.util.Map;
import java.util.OptionalInt;
import java.util.SortedMap;
import java.util.TreeMap;
+import java.util.function.ToIntFunction;
/** Container for meta-data of a data set. */
-public final class DataSetMetaInfo {
+public final class DataSetMetaInfo implements Serializable {
private static final int UNKNOWN = -1;
private final int numRegisteredPartitions;
@@ -36,7 +38,9 @@ public final class DataSetMetaInfo {
private final SortedMap<ResultPartitionID, ShuffleDescriptor>
shuffleDescriptorsOrderByPartitionId =
new TreeMap<>(
- Comparator.comparingInt(o -> o.getPartitionId().getPartitionNumber()));
+ Comparator.comparingInt(
+ (ToIntFunction<? super ResultPartitionID> & Serializable)
+ o -> o.getPartitionId().getPartitionNumber()));
private DataSetMetaInfo(int numRegisteredPartitions, int numTotalPartitions) {
this.numRegisteredPartitions = numRegisteredPartitions;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniCluster.java b/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniCluster.java
index dc6d09fe4f1..3a11fa93473 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniCluster.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniCluster.java
@@ -60,6 +60,8 @@ import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils;
import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServices;
import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServicesWithLeadershipControl;
import org.apache.flink.runtime.highavailability.nonha.embedded.HaLeadershipControl;
+import org.apache.flink.runtime.io.network.partition.ClusterPartitionManager;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.RestoreMode;
@@ -95,6 +97,7 @@ import org.apache.flink.runtime.webmonitor.retriever.LeaderRetriever;
import org.apache.flink.runtime.webmonitor.retriever.MetricQueryServiceRetriever;
import org.apache.flink.runtime.webmonitor.retriever.impl.RpcGatewayRetriever;
import org.apache.flink.runtime.webmonitor.retriever.impl.RpcMetricQueryServiceRetriever;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.AutoCloseableAsync;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.ExecutorUtils;
@@ -123,8 +126,10 @@ import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashSet;
import java.util.List;
import java.util.Optional;
+import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
@@ -1318,6 +1323,26 @@ public class MiniCluster implements AutoCloseableAsync {
}
}
+ public CompletableFuture<Void> invalidateClusterDataset(AbstractID clusterDatasetId) {
+ return resourceManagerGatewayRetriever
+ .getFuture()
+ .thenApply(
+ resourceManagerGateway ->
+ resourceManagerGateway.releaseClusterPartitions(
+ new IntermediateDataSetID(clusterDatasetId)))
+ .thenCompose(Function.identity());
+ }
+
+ public CompletableFuture<Set<AbstractID>> listCompletedClusterDatasetIds() {
+ return resourceManagerGatewayRetriever
+ .getFuture()
+ .thenApply(ClusterPartitionManager::listDataSets)
+ .thenCompose(
+ metaInfoMapFuture ->
+ metaInfoMapFuture.thenApply(
+ metaInfoMap -> new HashSet<>(metaInfoMap.keySet())));
+ }
+
/** Internal factory for {@link RpcService}. */
protected interface RpcServiceFactory {
RpcService createRpcService() throws Exception;
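In tests, these two MiniCluster methods make the cache lifecycle directly observable. A sketch of the kind of assertion they enable, assuming a running `miniCluster` whose caching job has already finished (AssertJ style, matching the tests in this commit):

    Set<AbstractID> ids = miniCluster.listCompletedClusterDatasetIds().get();
    assertThat(ids).isNotEmpty(); // the finished job left a cached dataset behind

    AbstractID first = ids.iterator().next();
    miniCluster.invalidateClusterDataset(first).get();
    assertThat(miniCluster.listCompletedClusterDatasetIds().get()).doesNotContain(first);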
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/dataset/ClusterDataSetEntry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/dataset/ClusterDataSetEntry.java
index 00e6dd8e049..0399bfea373 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/dataset/ClusterDataSetEntry.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/dataset/ClusterDataSetEntry.java
@@ -28,7 +28,7 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonPro
*
* @see ClusterDataSetListResponseBody
*/
-class ClusterDataSetEntry {
+public class ClusterDataSetEntry {
private static final String FIELD_NAME_DATA_SET_ID = "id";
private static final String FIELD_NAME_COMPLETE = "isComplete";
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/CachedIntermediateDataSetCorruptedException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ClusterDatasetCorruptedException.java
similarity index 66%
rename from flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/CachedIntermediateDataSetCorruptedException.java
rename to flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ClusterDatasetCorruptedException.java
index 1740a79fd31..a088f24d36c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/CachedIntermediateDataSetCorruptedException.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ClusterDatasetCorruptedException.java
@@ -24,21 +24,20 @@ import org.apache.flink.runtime.throwable.ThrowableType;
import java.util.List;
-/** Indicates some task fail to consume cached intermediate dataset. */
+/** Indicates that some task failed to consume a cluster dataset. */
@ThrowableAnnotation(ThrowableType.NonRecoverableError)
-public class CachedIntermediateDataSetCorruptedException extends JobException {
- private final List<IntermediateDataSetID> corruptedIntermediateDataSetID;
+public class ClusterDatasetCorruptedException extends JobException {
+ private final List<IntermediateDataSetID> corruptedClusterDatasetIds;
- public CachedIntermediateDataSetCorruptedException(
- Throwable cause, List<IntermediateDataSetID> corruptedIntermediateDataSetID) {
+ public ClusterDatasetCorruptedException(
+ Throwable cause, List<IntermediateDataSetID> corruptedClusterDatasetIds) {
super(
- String.format(
- "Corrupted intermediate dataset IDs: %s", corruptedIntermediateDataSetID),
+ String.format("Corrupted cluster dataset IDs: %s", corruptedClusterDatasetIds),
cause);
- this.corruptedIntermediateDataSetID = corruptedIntermediateDataSetID;
+ this.corruptedClusterDatasetIds = corruptedClusterDatasetIds;
}
- public List<IntermediateDataSetID> getCorruptedIntermediateDataSetID() {
- return corruptedIntermediateDataSetID;
+ public List<IntermediateDataSetID> getCorruptedClusterDatasetIds() {
+ return corruptedClusterDatasetIds;
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
index 57afb856f21..da59e2275fd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
@@ -239,8 +239,7 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
execution.getFailureInfo().get().getException().deserializeError(userCodeLoader);
handleTaskFailure(
execution,
- maybeTranslateToCachedIntermediateDataSetException(
- error, execution.getVertex().getID()));
+ maybeTranslateToClusterDatasetException(error, execution.getVertex().getID()));
}
protected void handleTaskFailure(
@@ -257,7 +256,7 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
return executionFailureHandler.getFailureHandlingResult(failedExecution, error, timestamp);
}
- private Throwable maybeTranslateToCachedIntermediateDataSetException(
+ private Throwable maybeTranslateToClusterDatasetException(
@Nullable Throwable cause, ExecutionVertexID failedVertex) {
if (!(cause instanceof PartitionException)) {
return cause;
@@ -275,7 +274,7 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
return cause;
}
- return new CachedIntermediateDataSetCorruptedException(
+ return new ClusterDatasetCorruptedException(
cause, Collections.singletonList(failedPartitionId.getIntermediateDataSetID()));
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java
index 1074da2d340..32c8b54357b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java
@@ -39,7 +39,7 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
-import org.apache.flink.runtime.scheduler.CachedIntermediateDataSetCorruptedException;
+import org.apache.flink.runtime.scheduler.ClusterDatasetCorruptedException;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorResource;
@@ -162,7 +162,7 @@ public class TaskDeploymentDescriptorFactoryTest extends TestLogger {
}
private static TaskDeploymentDescriptor createTaskDeploymentDescriptor(ExecutionVertex ev)
- throws IOException, CachedIntermediateDataSetCorruptedException {
+ throws IOException, ClusterDatasetCorruptedException {
return TaskDeploymentDescriptorFactory.fromExecution(ev.getCurrentExecutionAttempt())
.createDeploymentDescriptor(new AllocationID(), null, Collections.emptyList());
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
index 907bef037ae..7f9488fe901 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
@@ -32,22 +32,17 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.minicluster.TestingMiniCluster;
import org.apache.flink.runtime.minicluster.TestingMiniClusterConfiguration;
-import org.apache.flink.runtime.scheduler.CachedIntermediateDataSetCorruptedException;
+import org.apache.flink.runtime.scheduler.ClusterDatasetCorruptedException;
import org.apache.flink.types.IntValue;
-import org.junit.Test;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-
/** Integration tests for reusing persisted intermediate dataset */
public class JobIntermediateDatasetReuseTest {
@@ -56,12 +51,14 @@ public class JobIntermediateDatasetReuseTest {
@Test
public void testClusterPartitionReuse() throws Exception {
- internalTestClusterPartitionReuse(1, 1, jobResult -> assertTrue(jobResult.isSuccess()));
+ internalTestClusterPartitionReuse(
+ 1, 1, jobResult -> Assertions.assertThat(jobResult.isSuccess()).isTrue());
}
@Test
public void testClusterPartitionReuseMultipleParallelism() throws Exception {
- internalTestClusterPartitionReuse(64, 64, jobResult -> assertTrue(jobResult.isSuccess()));
+ internalTestClusterPartitionReuse(
+ 64, 64, jobResult -> Assertions.assertThat(jobResult.isSuccess()).isTrue());
}
@Test
@@ -71,8 +68,9 @@ public class JobIntermediateDatasetReuseTest {
1,
2,
jobResult -> {
- assertFalse(jobResult.isSuccess());
- assertNotNull(getCachedIntermediateDataSetCorruptedException(jobResult));
+ Assertions.assertThat(jobResult.isSuccess()).isFalse();
+ Assertions.assertThat(getClusterDatasetCorruptedException(jobResult))
+ .isNotNull();
});
}
@@ -83,8 +81,9 @@ public class JobIntermediateDatasetReuseTest {
2,
1,
jobResult -> {
- assertFalse(jobResult.isSuccess());
- assertNotNull(getCachedIntermediateDataSetCorruptedException(jobResult));
+ Assertions.assertThat(jobResult.isSuccess()).isFalse();
+ Assertions.assertThat(getClusterDatasetCorruptedException(jobResult))
+ .isNotNull();
});
}
@@ -107,7 +106,7 @@ public class JobIntermediateDatasetReuseTest {
CompletableFuture<JobResult> jobResultFuture =
miniCluster.requestJobResult(firstJobGraph.getJobID());
JobResult jobResult = jobResultFuture.get();
- assertTrue(jobResult.isSuccess());
+ Assertions.assertThat(jobResult.isSuccess()).isTrue();
final JobGraph secondJobGraph =
createSecondJobGraph(consumerParallelism, intermediateDataSetID);
@@ -133,7 +132,7 @@ public class JobIntermediateDatasetReuseTest {
CompletableFuture<JobResult> jobResultFuture =
miniCluster.requestJobResult(firstJobGraph.getJobID());
JobResult jobResult = jobResultFuture.get();
- assertTrue(jobResult.isSuccess());
+ Assertions.assertThat(jobResult.isSuccess()).isTrue();
miniCluster.terminateTaskManager(0);
miniCluster.startTaskManager();
@@ -145,38 +144,38 @@ public class JobIntermediateDatasetReuseTest {
miniCluster.submitJob(secondJobGraph).get();
jobResultFuture = miniCluster.requestJobResult(secondJobGraph.getJobID());
jobResult = jobResultFuture.get();
- assertFalse(jobResult.isSuccess());
- final CachedIntermediateDataSetCorruptedException exception =
- getCachedIntermediateDataSetCorruptedException(jobResult);
- assertNotNull(exception);
- assertEquals(
- intermediateDataSetID, exception.getCorruptedIntermediateDataSetID().get(0));
+ Assertions.assertThat(jobResult.isSuccess()).isFalse();
+ final ClusterDatasetCorruptedException exception =
+ getClusterDatasetCorruptedException(jobResult);
+ Assertions.assertThat(exception).isNotNull();
+ Assertions.assertThat(exception.getCorruptedClusterDatasetIds().get(0))
+ .isEqualTo(intermediateDataSetID);
firstJobGraph.setJobID(new JobID());
miniCluster.submitJob(firstJobGraph).get();
jobResultFuture = miniCluster.requestJobResult(firstJobGraph.getJobID());
jobResult = jobResultFuture.get();
- assertTrue(jobResult.isSuccess());
+ Assertions.assertThat(jobResult.isSuccess()).isTrue();
secondJobGraph.setJobID(new JobID());
miniCluster.submitJob(secondJobGraph).get();
jobResultFuture = miniCluster.requestJobResult(secondJobGraph.getJobID());
jobResult = jobResultFuture.get();
- assertTrue(jobResult.isSuccess());
+ Assertions.assertThat(jobResult.isSuccess()).isTrue();
}
}
- private CachedIntermediateDataSetCorruptedException
- getCachedIntermediateDataSetCorruptedException(JobResult jobResult) {
- assertTrue(jobResult.getSerializedThrowable().isPresent());
+ private ClusterDatasetCorruptedException getClusterDatasetCorruptedException(
+ JobResult jobResult) {
+ Assertions.assertThat(jobResult.getSerializedThrowable().isPresent()).isTrue();
Throwable throwable =
jobResult
.getSerializedThrowable()
.get()
.deserializeError(Thread.currentThread().getContextClassLoader());
while (throwable != null) {
- if (throwable instanceof CachedIntermediateDataSetCorruptedException) {
- return (CachedIntermediateDataSetCorruptedException) throwable;
+ if (throwable instanceof ClusterDatasetCorruptedException) {
+ return (ClusterDatasetCorruptedException) throwable;
}
throwable = throwable.getCause();
}
@@ -261,10 +260,10 @@ public class JobIntermediateDatasetReuseTest {
for (int i = index; i < index + 100; ++i) {
final int value = reader.next().getValue();
LOG.debug("Receiver({}) received {}", index, value);
- assertEquals(i, value);
+ Assertions.assertThat(value).isEqualTo(i);
}
- assertNull(reader.next());
+ Assertions.assertThat(reader.next()).isNull();
}
}
}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/CachedDataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/CachedDataStream.java
new file mode 100644
index 00000000000..5717c3272de
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/CachedDataStream.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.datastream;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.CacheTransformation;
+
+/**
+ * {@link CachedDataStream} represents a {@link DataStream} whose intermediate result is cached
+ * the first time it is computed. The cached intermediate result can then be used by later jobs
+ * that refer to the same {@link CachedDataStream}, avoiding re-computation.
+ *
+ * @param <T> The type of the elements in this stream.
+ */
+@PublicEvolving
+public class CachedDataStream<T> extends DataStream<T> {
+ /**
+ * Creates a new {@link CachedDataStream} in the given execution environment that wraps the
+ * given physical transformation to indicate that its intermediate result should be cached.
+ *
+ * @param environment The StreamExecutionEnvironment
+ * @param transformation The physical transformation whose intermediate result should be cached.
+ */
+ public CachedDataStream(
+ StreamExecutionEnvironment environment, Transformation<T> transformation) {
+ super(
+ environment,
+ new CacheTransformation<>(
+ transformation, String.format("Cache: %s", transformation.getName())));
+
+ final CacheTransformation<T> t = (CacheTransformation<T>) this.getTransformation();
+ environment.registerCacheTransformation(t.getDatasetId(), t);
+ }
+
+ /**
+ * Invalidates the cached intermediate result of this DataStream to release the physical
+ * resources. Users are not required to invoke this method unless they want to release the
+ * resources eagerly. The cache will be recreated if the stream is used again after invalidation.
+ */
+ public void invalidate() throws Exception {
+ final CacheTransformation<T> t = (CacheTransformation<T>) this.getTransformation();
+ environment.invalidateClusterDataset(t.getDatasetId());
+ }
+}
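Putting the user-facing pieces together, typical usage looks roughly like the following (a sketch under the documented constraints: bounded input and batch runtime mode; `expensiveTransform` is a hypothetical costly step):

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setRuntimeMode(RuntimeExecutionMode.BATCH); // caching requires bounded execution

    CachedDataStream<Long> cached =
            env.fromSequence(0, 9_999)
                    .map(x -> expensiveTransform(x)) // hypothetical costly computation
                    .returns(Types.LONG)
                    .cache();

    cached.print();
    env.execute(); // first job: computes and caches the intermediate result

    cached.map(x -> x + 1).returns(Types.LONG).print();
    env.execute(); // second job: reads the cache instead of recomputing

    cached.invalidate(); // optional: release the cluster partitions eagerly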
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SideOutputDataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SideOutputDataStream.java
new file mode 100644
index 00000000000..fde78bd9641
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SideOutputDataStream.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.datastream;
+
+import org.apache.flink.annotation.Public;
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
+
+/**
+ * A {@link SideOutputDataStream} represents a {@link DataStream} that contains elements that are
+ * emitted from upstream into a side output with some tag.
+ *
+ * @param <T> The type of the elements in this stream.
+ */
+@Public
+public class SideOutputDataStream<T> extends DataStream<T> {
+ /**
+ * Creates a new {@link SideOutputDataStream} in the given execution environment.
+ *
+ * @param environment The StreamExecutionEnvironment
+ * @param transformation The SideOutputTransformation
+ */
+ public SideOutputDataStream(
+ StreamExecutionEnvironment environment, SideOutputTransformation<T> transformation) {
+ super(environment, transformation);
+ }
+
+ /**
+ * Caches the intermediate result of the transformation. Only bounded streams are supported,
+ * and currently only blocking mode is supported. The cache is generated lazily the first time
+ * the intermediate result is computed. The cache is cleared when {@link
+ * CachedDataStream#invalidate()} is called or when the {@link StreamExecutionEnvironment} is
+ * closed.
+ *
+ * @return A CachedDataStream that can be used in later jobs to reuse the cached intermediate
+ *     result.
+ */
+ @PublicEvolving
+ public CachedDataStream<T> cache() {
+ return new CachedDataStream<>(this.environment, this.transformation);
+ }
+}
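Since getSideOutput now returns this dedicated subclass, side outputs can be cached just like main outputs. A sketch, where `events` is an assumed bounded DataStream<String> and the tag and routing logic are illustrative:

    final OutputTag<String> rejectedTag = new OutputTag<String>("rejected") {};

    SingleOutputStreamOperator<String> accepted =
            events.process(
                    new ProcessFunction<String, String>() {
                        @Override
                        public void processElement(String value, Context ctx, Collector<String> out) {
                            if (value.startsWith("ok:")) {
                                out.collect(value); // main output
                            } else {
                                ctx.output(rejectedTag, value); // side output
                            }
                        }
                    });

    CachedDataStream<String> cachedRejected = accepted.getSideOutput(rejectedTag).cache();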
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SingleOutputStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SingleOutputStreamOperator.java
index 2a87b5dc0cd..4141a31d85c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SingleOutputStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SingleOutputStreamOperator.java
@@ -399,7 +399,7 @@ public class SingleOutputStreamOperator<T> extends DataStream<T> {
* @see org.apache.flink.streaming.api.functions.ProcessFunction.Context#output(OutputTag,
* Object)
*/
- public <X> DataStream<X> getSideOutput(OutputTag<X> sideOutputTag) {
+ public <X> SideOutputDataStream<X> getSideOutput(OutputTag<X> sideOutputTag) {
sideOutputTag = clean(requireNonNull(sideOutputTag));
// make a defensive copy
@@ -417,7 +417,7 @@ public class SingleOutputStreamOperator<T> extends DataStream<T> {
SideOutputTransformation<X> sideOutputTransformation =
new SideOutputTransformation<>(this.getTransformation(), sideOutputTag);
- return new DataStream<>(this.getExecutionEnvironment(), sideOutputTransformation);
+ return new SideOutputDataStream<>(this.getExecutionEnvironment(), sideOutputTransformation);
}
/**
@@ -437,4 +437,22 @@ public class SingleOutputStreamOperator<T> extends DataStream<T> {
transformation.setDescription(description);
return this;
}
+
+ /**
+ * Caches the intermediate result of the transformation. Only bounded streams are supported,
+ * and currently only blocking mode is supported. The cache is generated lazily the first time
+ * the intermediate result is computed. The cache is cleared when {@link
+ * CachedDataStream#invalidate()} is called or when the {@link StreamExecutionEnvironment} is
+ * closed.
+ *
+ * @return A CachedDataStream that can be used in later jobs to reuse the cached intermediate
+ *     result.
+ */
+ @PublicEvolving
+ public CachedDataStream<T> cache() {
+ if (!(this.transformation instanceof PhysicalTransformation)) {
+ throw new IllegalStateException(
+ "Cache can only be called with physical transformation or side output transformation");
+ }
+
+ return new CachedDataStream<>(this.environment, this.transformation);
+ }
}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java
index e298905a566..22b7d27be07 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java
@@ -63,6 +63,7 @@ import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.configuration.RestOptions;
import org.apache.flink.configuration.StateChangelogOptions;
import org.apache.flink.configuration.UnmodifiableConfiguration;
+import org.apache.flink.core.execution.CacheSupportedPipelineExecutor;
import org.apache.flink.core.execution.DefaultExecutorServiceLoader;
import org.apache.flink.core.execution.DetachedJobExecutionResult;
import org.apache.flink.core.execution.JobClient;
@@ -72,6 +73,7 @@ import org.apache.flink.core.execution.PipelineExecutorFactory;
import org.apache.flink.core.execution.PipelineExecutorServiceLoader;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
+import org.apache.flink.runtime.scheduler.ClusterDatasetCorruptedException;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.StateBackendLoader;
@@ -98,6 +100,8 @@ import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.StreamGraphGenerator;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.operators.collect.CollectResultIterator;
+import org.apache.flink.streaming.api.transformations.CacheTransformation;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.DynamicCodeLoadingException;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FlinkException;
@@ -124,8 +128,10 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -142,7 +148,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
* @see org.apache.flink.streaming.api.environment.RemoteStreamEnvironment
*/
@Public
-public class StreamExecutionEnvironment {
+public class StreamExecutionEnvironment implements AutoCloseable {
private static final List<CollectResultIterator<?>> collectIterators = new ArrayList<>();
@@ -185,6 +191,8 @@ public class StreamExecutionEnvironment {
protected final List<Transformation<?>> transformations = new ArrayList<>();
+ private final Map<AbstractID, CacheTransformation<?>> cachedTransformations = new HashMap<>();
+
private long bufferTimeout = ExecutionOptions.BUFFER_TIMEOUT.defaultValue().toMillis();
protected boolean isChainingEnabled = true;
@@ -2027,7 +2035,7 @@ public class StreamExecutionEnvironment {
* @throws Exception which occurs during job execution.
*/
public JobExecutionResult execute() throws Exception {
- return execute(getStreamGraph());
+ return execute((String) null);
}
/**
@@ -2042,10 +2050,26 @@ public class StreamExecutionEnvironment {
* @throws Exception which occurs during job execution.
*/
public JobExecutionResult execute(String jobName) throws Exception {
- Preconditions.checkNotNull(jobName, "Streaming Job name should not be null.");
- final StreamGraph streamGraph = getStreamGraph();
- streamGraph.setJobName(jobName);
- return execute(streamGraph);
+ final List<Transformation<?>> originalTransformations = new ArrayList<>(transformations);
+ StreamGraph streamGraph = getStreamGraph();
+ if (jobName != null) {
+ streamGraph.setJobName(jobName);
+ }
+
+ try {
+ return execute(streamGraph);
+ } catch (Throwable t) {
+ Optional<ClusterDatasetCorruptedException> clusterDatasetCorruptedException =
+ ExceptionUtils.findThrowable(t, ClusterDatasetCorruptedException.class);
+ if (!clusterDatasetCorruptedException.isPresent()) {
+ throw t;
+ }
+
+            // Retry without cache if the failure was caused by a corrupted cluster dataset.
+ invalidateCacheTransformations(originalTransformations);
+ streamGraph = getStreamGraph(originalTransformations);
+ return execute(streamGraph);
+ }
}
/**
@@ -2091,6 +2115,19 @@ public class StreamExecutionEnvironment {
}
}
+ private void invalidateCacheTransformations(List<Transformation<?>> transformations)
+ throws Exception {
+ for (Transformation<?> transformation : transformations) {
+ if (transformation == null) {
+ continue;
+ }
+ if (transformation instanceof CacheTransformation) {
+ invalidateClusterDataset(((CacheTransformation<?>) transformation).getDatasetId());
+ }
+ invalidateCacheTransformations(transformation.getInputs());
+ }
+ }
+
/**
* Register a {@link JobListener} in this environment. The {@link JobListener} will be notified
* on specific job status changed.
@@ -2156,22 +2193,10 @@ public class StreamExecutionEnvironment {
@Internal
public JobClient executeAsync(StreamGraph streamGraph) throws Exception {
checkNotNull(streamGraph, "StreamGraph cannot be null.");
- checkNotNull(
- configuration.get(DeploymentOptions.TARGET),
- "No execution.target specified in your configuration file.");
-
- final PipelineExecutorFactory executorFactory =
- executorServiceLoader.getExecutorFactory(configuration);
-
- checkNotNull(
- executorFactory,
- "Cannot find compatible factory for specified execution.target (=%s)",
- configuration.get(DeploymentOptions.TARGET));
+ final PipelineExecutor executor = getPipelineExecutor();
CompletableFuture<JobClient> jobClientFuture =
- executorFactory
- .getExecutor(configuration)
- .execute(streamGraph, configuration, userClassloader);
+ executor.execute(streamGraph, configuration, userClassloader);
try {
JobClient jobClient = jobClientFuture.get();
@@ -2213,13 +2238,32 @@ public class StreamExecutionEnvironment {
*/
@Internal
public StreamGraph getStreamGraph(boolean clearTransformations) {
- final StreamGraph streamGraph = getStreamGraphGenerator(transformations).generate();
+ final StreamGraph streamGraph = getStreamGraph(transformations);
if (clearTransformations) {
transformations.clear();
}
return streamGraph;
}
+ private StreamGraph getStreamGraph(List<Transformation<?>> transformations) {
+ synchronizeClusterDatasetStatus();
+ return getStreamGraphGenerator(transformations).generate();
+ }
+
+ private void synchronizeClusterDatasetStatus() {
+ if (cachedTransformations.isEmpty()) {
+ return;
+ }
+ Set<AbstractID> completedClusterDatasets =
+ listCompletedClusterDatasets().stream()
+ .map(AbstractID::new)
+ .collect(Collectors.toSet());
+ cachedTransformations.forEach(
+ (id, transformation) -> {
+ transformation.setCached(completedClusterDatasets.contains(id));
+ });
+ }
+
/**
* Generates a {@link StreamGraph} that consists of the given {@link Transformation
* transformations} and is configured with the configuration of this environment.
@@ -2601,4 +2645,69 @@ public class StreamExecutionEnvironment {
public List<Transformation<?>> getTransformations() {
return transformations;
}
+
+ @Internal
+ public <T> void registerCacheTransformation(
+ AbstractID intermediateDataSetID, CacheTransformation<T> t) {
+ cachedTransformations.put(intermediateDataSetID, t);
+ }
+
+ @Internal
+ public void invalidateClusterDataset(AbstractID datasetId) throws Exception {
+ if (!cachedTransformations.containsKey(datasetId)) {
+ throw new RuntimeException(
+ String.format("IntermediateDataset %s is not found", datasetId));
+ }
+ final PipelineExecutor executor = getPipelineExecutor();
+
+ if (!(executor instanceof CacheSupportedPipelineExecutor)) {
+ return;
+ }
+
+ ((CacheSupportedPipelineExecutor) executor)
+ .invalidateClusterDataset(datasetId, configuration, userClassloader)
+ .get();
+ cachedTransformations.get(datasetId).setCached(false);
+ }
+
+ protected Set<AbstractID> listCompletedClusterDatasets() {
+ try {
+ final PipelineExecutor executor = getPipelineExecutor();
+ if (!(executor instanceof CacheSupportedPipelineExecutor)) {
+ return Collections.emptySet();
+ }
+ return ((CacheSupportedPipelineExecutor) executor)
+ .listCompletedClusterDatasetIds(configuration, userClassloader)
+ .get();
+ } catch (Throwable e) {
+ return Collections.emptySet();
+ }
+ }
+
+ /**
+     * Close and clean up the execution environment. All cached intermediate results will be
+     * physically released.
+ */
+ @Override
+ public void close() throws Exception {
+ for (AbstractID id : cachedTransformations.keySet()) {
+ invalidateClusterDataset(id);
+ }
+ }
+
+ private PipelineExecutor getPipelineExecutor() throws Exception {
+ checkNotNull(
+ configuration.get(DeploymentOptions.TARGET),
+ "No execution.target specified in your configuration file.");
+
+ final PipelineExecutorFactory executorFactory =
+ executorServiceLoader.getExecutorFactory(configuration);
+
+ checkNotNull(
+ executorFactory,
+ "Cannot find compatible factory for specified execution.target (=%s)",
+ configuration.get(DeploymentOptions.TARGET));
+
+ return executorFactory.getExecutor(configuration);
+ }
}
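Since StreamExecutionEnvironment now implements AutoCloseable, cached results can also be scoped with try-with-resources; a sketch (illustrative only, again assuming batch mode):

    // All cached cluster datasets registered on this environment are
    // invalidated when close() runs at the end of the block.
    try (StreamExecutionEnvironment env =
            StreamExecutionEnvironment.getExecutionEnvironment()) {
        env.setRuntimeMode(RuntimeExecutionMode.BATCH);
        CachedDataStream<Long> cached = env.fromSequence(1, 100).cache();
        cached.print();
        env.execute();
    }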
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
index 94ab343f855..704586badec 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
@@ -18,6 +18,7 @@
package org.apache.flink.streaming.api.graph;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.streaming.api.transformations.StreamExchangeMode;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.util.OutputTag;
@@ -73,6 +74,8 @@ public class StreamEdge implements Serializable {
private boolean supportsUnalignedCheckpoints = true;
+ private final IntermediateDataSetID intermediateDatasetIdToProduce;
+
public StreamEdge(
StreamNode sourceVertex,
StreamNode targetVertex,
@@ -88,7 +91,8 @@ public class StreamEdge implements Serializable {
outputPartitioner,
outputTag,
StreamExchangeMode.UNDEFINED,
- 0);
+ 0,
+ null);
}
public StreamEdge(
@@ -98,7 +102,8 @@ public class StreamEdge implements Serializable {
StreamPartitioner<?> outputPartitioner,
OutputTag outputTag,
StreamExchangeMode exchangeMode,
- int uniqueId) {
+ int uniqueId,
+ IntermediateDataSetID intermediateDatasetId) {
this(
sourceVertex,
@@ -108,7 +113,8 @@ public class StreamEdge implements Serializable {
outputPartitioner,
outputTag,
exchangeMode,
- uniqueId);
+ uniqueId,
+ intermediateDatasetId);
}
public StreamEdge(
@@ -119,7 +125,8 @@ public class StreamEdge implements Serializable {
StreamPartitioner<?> outputPartitioner,
OutputTag outputTag,
StreamExchangeMode exchangeMode,
- int uniqueId) {
+ int uniqueId,
+ IntermediateDataSetID intermediateDatasetId) {
this.sourceId = sourceVertex.getId();
this.targetId = targetVertex.getId();
@@ -131,6 +138,7 @@ public class StreamEdge implements Serializable {
this.sourceOperatorName = sourceVertex.getOperatorName();
this.targetOperatorName = targetVertex.getOperatorName();
this.exchangeMode = checkNotNull(exchangeMode);
+ this.intermediateDatasetIdToProduce = intermediateDatasetId;
this.edgeId =
sourceVertex
+ "_"
@@ -226,4 +234,8 @@ public class StreamEdge implements Serializable {
+ uniqueId
+ ')';
}
+
+ public IntermediateDataSetID getIntermediateDatasetIdToProduce() {
+ return intermediateDatasetIdToProduce;
+ }
}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
index f1051a3f91e..8420539d798 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
@@ -37,6 +37,7 @@ import org.apache.flink.core.fs.Path;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.executiongraph.JobStatusHook;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobType;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
@@ -611,6 +612,14 @@ public class StreamGraph implements Pipeline {
}
public void addEdge(Integer upStreamVertexID, Integer downStreamVertexID, int typeNumber) {
+ addEdge(upStreamVertexID, downStreamVertexID, typeNumber, null);
+ }
+
+ public void addEdge(
+ Integer upStreamVertexID,
+ Integer downStreamVertexID,
+ int typeNumber,
+ IntermediateDataSetID intermediateDataSetId) {
addEdgeInternal(
upStreamVertexID,
downStreamVertexID,
@@ -618,7 +627,8 @@ public class StreamGraph implements Pipeline {
null,
new ArrayList<String>(),
null,
- null);
+ null,
+ intermediateDataSetId);
}
private void addEdgeInternal(
@@ -628,7 +638,8 @@ public class StreamGraph implements Pipeline {
StreamPartitioner<?> partitioner,
List<String> outputNames,
OutputTag outputTag,
- StreamExchangeMode exchangeMode) {
+ StreamExchangeMode exchangeMode,
+ IntermediateDataSetID intermediateDataSetId) {
if (virtualSideOutputNodes.containsKey(upStreamVertexID)) {
int virtualId = upStreamVertexID;
@@ -643,7 +654,8 @@ public class StreamGraph implements Pipeline {
partitioner,
null,
outputTag,
- exchangeMode);
+ exchangeMode,
+ intermediateDataSetId);
} else if (virtualPartitionNodes.containsKey(upStreamVertexID)) {
int virtualId = upStreamVertexID;
upStreamVertexID = virtualPartitionNodes.get(virtualId).f0;
@@ -658,7 +670,8 @@ public class StreamGraph implements Pipeline {
partitioner,
outputNames,
outputTag,
- exchangeMode);
+ exchangeMode,
+ intermediateDataSetId);
} else {
createActualEdge(
upStreamVertexID,
@@ -666,7 +679,8 @@ public class StreamGraph implements Pipeline {
typeNumber,
partitioner,
outputTag,
- exchangeMode);
+ exchangeMode,
+ intermediateDataSetId);
}
}
@@ -676,7 +690,8 @@ public class StreamGraph implements Pipeline {
int typeNumber,
StreamPartitioner<?> partitioner,
OutputTag outputTag,
- StreamExchangeMode exchangeMode) {
+ StreamExchangeMode exchangeMode,
+ IntermediateDataSetID intermediateDataSetId) {
StreamNode upstreamNode = getStreamNode(upStreamVertexID);
StreamNode downstreamNode = getStreamNode(downStreamVertexID);
@@ -728,7 +743,8 @@ public class StreamGraph implements Pipeline {
partitioner,
outputTag,
exchangeMode,
- uniqueId);
+ uniqueId,
+ intermediateDataSetId);
getStreamNode(edge.getSourceId()).addOutEdge(edge);
getStreamNode(edge.getTargetId()).addInEdge(edge);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
index 0fd63c49192..85d635a9c8c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
@@ -49,6 +49,7 @@ import org.apache.flink.streaming.api.operators.sorted.state.BatchExecutionCheck
import org.apache.flink.streaming.api.operators.sorted.state.BatchExecutionInternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.sorted.state.BatchExecutionStateBackend;
import org.apache.flink.streaming.api.transformations.BroadcastStateTransformation;
+import org.apache.flink.streaming.api.transformations.CacheTransformation;
import org.apache.flink.streaming.api.transformations.CoFeedbackTransformation;
import org.apache.flink.streaming.api.transformations.FeedbackTransformation;
import org.apache.flink.streaming.api.transformations.KeyedBroadcastStateTransformation;
@@ -69,6 +70,7 @@ import org.apache.flink.streaming.api.transformations.UnionTransformation;
import org.apache.flink.streaming.api.transformations.WithBoundedness;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.translators.BroadcastStateTransformationTranslator;
+import org.apache.flink.streaming.runtime.translators.CacheTransformationTranslator;
import org.apache.flink.streaming.runtime.translators.KeyedBroadcastStateTransformationTranslator;
import org.apache.flink.streaming.runtime.translators.LegacySinkTransformationTranslator;
import org.apache.flink.streaming.runtime.translators.LegacySourceTransformationTranslator;
@@ -205,6 +207,7 @@ public class StreamGraphGenerator {
tmp.put(
KeyedBroadcastStateTransformation.class,
new KeyedBroadcastStateTransformationTranslator<>());
+ tmp.put(CacheTransformation.class, new CacheTransformationTranslator<>());
translatorMap = Collections.unmodifiableMap(tmp);
}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
index 1fcae5b9136..68edc370b5d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.operators.ResourceSpec;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable;
import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
@@ -90,6 +91,8 @@ public class StreamNode {
private final Map<Integer, StreamConfig.InputRequirement> inputRequirements = new HashMap<>();
+ private @Nullable IntermediateDataSetID consumeClusterDatasetId;
+
@VisibleForTesting
public StreamNode(
Integer id,
@@ -405,4 +408,14 @@ public class StreamNode {
public int hashCode() {
return id;
}
+
+ @Nullable
+ public IntermediateDataSetID getConsumeClusterDatasetId() {
+ return consumeClusterDatasetId;
+ }
+
+ public void setConsumeClusterDatasetId(
+ @Nullable IntermediateDataSetID consumeClusterDatasetId) {
+ this.consumeClusterDatasetId = consumeClusterDatasetId;
+ }
}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 652ef88dcc0..be72cacff5c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -36,6 +36,7 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.InputOutputFormatContainer;
import org.apache.flink.runtime.jobgraph.InputOutputFormatVertex;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphUtils;
@@ -797,6 +798,10 @@ public class StreamingJobGraphGenerator {
jobVertex = new JobVertex(chainedNames.get(streamNodeId), jobVertexId, operatorIDPairs);
}
+ if (streamNode.getConsumeClusterDatasetId() != null) {
+ jobVertex.addIntermediateDataSetIdToConsume(streamNode.getConsumeClusterDatasetId());
+ }
+
for (OperatorCoordinator.Provider coordinatorProvider :
chainInfo.getCoordinatorProviders()) {
coordinatorSerializationFutures.add(
@@ -899,6 +904,13 @@ public class StreamingJobGraphGenerator {
}
}
}
+
+        // set the input config of the vertex if it consumes from a cached intermediate dataset.
+ if (vertex.getConsumeClusterDatasetId() != null) {
+ config.setNumberOfNetworkInputs(1);
+ inputConfigs[0] = new StreamConfig.NetworkInputConfig(inputSerializers[0], 0);
+ }
+
config.setInputs(inputConfigs);
config.setTypeSerializerOut(vertex.getTypeSerializerOut());
@@ -1052,18 +1064,31 @@ public class StreamingJobGraphGenerator {
hasHybridResultPartition = true;
}
+ IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
+ if (isPersistentIntermediateDataset(resultPartitionType, edge)) {
+ resultPartitionType = ResultPartitionType.BLOCKING_PERSISTENT;
+ intermediateDataSetID = edge.getIntermediateDatasetIdToProduce();
+ }
+
checkBufferTimeout(resultPartitionType, edge);
JobEdge jobEdge;
if (partitioner.isPointwise()) {
jobEdge =
downStreamVertex.connectNewDataSetAsInput(
- headVertex, DistributionPattern.POINTWISE, resultPartitionType);
+ headVertex,
+ DistributionPattern.POINTWISE,
+ resultPartitionType,
+ intermediateDataSetID);
} else {
jobEdge =
downStreamVertex.connectNewDataSetAsInput(
- headVertex, DistributionPattern.ALL_TO_ALL, resultPartitionType);
+ headVertex,
+ DistributionPattern.ALL_TO_ALL,
+ resultPartitionType,
+ intermediateDataSetID);
}
+
// set strategy name so that web interface can show it.
jobEdge.setShipStrategyName(partitioner.toString());
jobEdge.setBroadcast(partitioner.isBroadcast());
@@ -1080,6 +1105,12 @@ public class StreamingJobGraphGenerator {
}
}
+ private boolean isPersistentIntermediateDataset(
+ ResultPartitionType resultPartitionType, StreamEdge edge) {
+ return resultPartitionType.isBlockingOrBlockingPersistentResultPartition()
+ && edge.getIntermediateDatasetIdToProduce() != null;
+ }
+
private void checkBufferTimeout(ResultPartitionType type, StreamEdge edge) {
long bufferTimeout = edge.getBufferTimeout();
if (!type.canBePipelinedConsumed()
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java
new file mode 100644
index 00000000000..95c219f95a6
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.transformations;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.util.AbstractID;
+
+import org.apache.flink.shaded.guava30.com.google.common.collect.Lists;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * In batch mode, the {@link CacheTransformation} indicates that the intermediate result of the
+ * upstream transformation should be cached when it is first computed, and that later jobs should
+ * consume the cached intermediate result. In stream mode, it has no effect.
+ *
+ * @param <T> The type of the elements in the cached intermediate result.
+ */
+@Internal
+public class CacheTransformation<T> extends Transformation<T> {
+ private final Transformation<T> transformationToCache;
+ private final AbstractID datasetId;
+ private boolean isCached;
+ /**
+ * Creates a new {@code Transformation} with the given name, output type and parallelism.
+ *
+ * @param name The name of the {@code Transformation}, this will be shown in Visualizations and
+ * the Log
+ */
+ public CacheTransformation(Transformation<T> transformationToCache, String name) {
+ super(name, transformationToCache.getOutputType(), transformationToCache.getParallelism());
+ this.transformationToCache = transformationToCache;
+
+ this.datasetId = new AbstractID();
+ this.isCached = false;
+ }
+
+ @Override
+ public List<Transformation<?>> getTransitivePredecessors() {
+ List<Transformation<?>> result = Lists.newArrayList();
+ result.add(this);
+ if (isCached) {
+ return result;
+ }
+ result.addAll(transformationToCache.getTransitivePredecessors());
+ return result;
+ }
+
+ @Override
+ public List<Transformation<?>> getInputs() {
+ if (isCached) {
+ return Collections.emptyList();
+ }
+ return Collections.singletonList(transformationToCache);
+ }
+
+ public AbstractID getDatasetId() {
+ return datasetId;
+ }
+
+ public Transformation<T> getTransformationToCache() {
+ return transformationToCache;
+ }
+
+ public void setCached(boolean cached) {
+ isCached = cached;
+ }
+
+ public boolean isCached() {
+ return isCached;
+ }
+}
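The isCached flag is what flips a cache node from producing to consuming: before the first execution, getInputs() exposes the upstream transformation; afterwards the node behaves as a source with no inputs. A sketch of that contract (illustrative only, built on a toy upstream transformation):

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    Transformation<Integer> toCache = env.fromElements(1, 2, 3).getTransformation();
    CacheTransformation<Integer> cache = new CacheTransformation<>(toCache, "cache");

    // First execution: the cache node is wired after the upstream, so the
    // producing side is part of the graph.
    assert cache.getInputs().size() == 1;

    // Once the cluster dataset is reported as completed:
    cache.setCached(true);

    // Later executions: the cache node acts as a source and cuts the
    // upstream subgraph out of the traversal entirely.
    assert cache.getInputs().isEmpty();
    assert cache.getTransitivePredecessors().size() == 1;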
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/CacheTransformationTranslator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/CacheTransformationTranslator.java
new file mode 100644
index 00000000000..c55998ab7fd
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/CacheTransformationTranslator.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.translators;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.streaming.api.graph.SimpleTransformationTranslator;
+import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.transformations.CacheTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+/** Translator for {@link CacheTransformation}. */
+@Internal
+public class CacheTransformationTranslator<OUT, T extends CacheTransformation<OUT>>
+ extends SimpleTransformationTranslator<OUT, T> {
+
+ public static final String CACHE_CONSUMER_OPERATOR_NAME = "CacheRead";
+ public static final String CACHE_PRODUCER_OPERATOR_NAME = "CacheWrite";
+
+ @Override
+ protected Collection<Integer> translateForBatchInternal(T transformation, Context context) {
+ if (!transformation.isCached()) {
+ final List<Transformation<?>> inputs = transformation.getInputs();
+ Preconditions.checkState(
+                    inputs.size() == 1, "There can be only one transformation input to cache");
+ Transformation<?> input = inputs.get(0);
+ if (input instanceof PhysicalTransformation) {
+ return physicalTransformationProduceCache(transformation, context, input);
+ } else if (input instanceof SideOutputTransformation) {
+ return sideOutputTransformationProduceCache(
+ transformation, context, (SideOutputTransformation<?>) input);
+ } else {
+ throw new RuntimeException(
+ String.format("Unsupported transformation %s", input.getClass()));
+ }
+ } else {
+ return consumeCache(transformation, context);
+ }
+ }
+
+ @Override
+ protected Collection<Integer> translateForStreamingInternal(T transformation, Context context) {
+ if (transformation.isCached()) {
+ return consumeCache(transformation, context);
+ } else {
+ throw new RuntimeException(
+                    "Producing a cached IntermediateResult is not supported in streaming mode");
+ }
+ }
+
+ private Collection<Integer> sideOutputTransformationProduceCache(
+ T transformation, Context context, SideOutputTransformation<?> input) {
+ final StreamGraph streamGraph = context.getStreamGraph();
+ // SideOutput Transformation has only one input
+ final Transformation<?> physicalTransformation = input.getInputs().get(0);
+
+ final Collection<Integer> cacheNodeIds = context.getStreamNodeIds(physicalTransformation);
+
+ Preconditions.checkState(
+ cacheNodeIds.size() == 1, "We expect only one stream node for the input transform");
+
+ final Integer cacheNodeId = cacheNodeIds.iterator().next();
+
+ addCacheProduceNode(streamGraph, transformation, context, physicalTransformation);
+
+ final int virtualId = Transformation.getNewNodeId();
+ streamGraph.addVirtualSideOutputNode(cacheNodeId, virtualId, input.getOutputTag());
+ streamGraph.addEdge(
+ virtualId,
+ transformation.getId(),
+ 0,
+ new IntermediateDataSetID(transformation.getDatasetId()));
+ return Collections.singletonList(virtualId);
+ }
+
+ private List<Integer> physicalTransformationProduceCache(
+ T transformation, Context context, Transformation<?> input) {
+ final StreamGraph streamGraph = context.getStreamGraph();
+ final Collection<Integer> cachedNodeIds = context.getStreamNodeIds(input);
+
+ Preconditions.checkState(
+ cachedNodeIds.size() == 1,
+ "We expect only one stream node for the input transform");
+
+ final Integer cacheNodeId = cachedNodeIds.iterator().next();
+
+ addCacheProduceNode(streamGraph, transformation, context, input);
+
+ streamGraph.addEdge(
+ cacheNodeId,
+ transformation.getId(),
+ 0,
+ new IntermediateDataSetID(transformation.getDatasetId()));
+ return Collections.singletonList(cacheNodeId);
+ }
+
+ private void addCacheProduceNode(
+ StreamGraph streamGraph,
+ T cacheTransformation,
+ Context context,
+ Transformation<?> input) {
+ final SimpleOperatorFactory<OUT> operatorFactory =
+ SimpleOperatorFactory.of(new NoOpStreamOperator<>());
+ operatorFactory.setChainingStrategy(ChainingStrategy.HEAD);
+ streamGraph.addOperator(
+ cacheTransformation.getId(),
+ context.getSlotSharingGroup(),
+ cacheTransformation.getCoLocationGroupKey(),
+ operatorFactory,
+ cacheTransformation.getInputs().get(0).getOutputType(),
+ null,
+ CACHE_PRODUCER_OPERATOR_NAME);
+
+ streamGraph.setParallelism(cacheTransformation.getId(), input.getParallelism());
+ streamGraph.setMaxParallelism(cacheTransformation.getId(), input.getMaxParallelism());
+ }
+
+ private List<Integer> consumeCache(T transformation, Context context) {
+ final StreamGraph streamGraph = context.getStreamGraph();
+ final SimpleOperatorFactory<OUT> operatorFactory =
+ SimpleOperatorFactory.of(new IdentityStreamOperator<>());
+ final TypeInformation<OUT> outputType =
+ transformation.getTransformationToCache().getOutputType();
+ streamGraph.addLegacySource(
+ transformation.getId(),
+ context.getSlotSharingGroup(),
+ transformation.getCoLocationGroupKey(),
+ operatorFactory,
+ outputType,
+ outputType,
+ CACHE_CONSUMER_OPERATOR_NAME);
+ streamGraph.setParallelism(
+ transformation.getId(), transformation.getTransformationToCache().getParallelism());
+ streamGraph.setMaxParallelism(
+ transformation.getId(),
+ transformation.getTransformationToCache().getMaxParallelism());
+ final StreamNode streamNode = streamGraph.getStreamNode(transformation.getId());
+ streamNode.setConsumeClusterDatasetId(
+ new IntermediateDataSetID(transformation.getDatasetId()));
+ return Collections.singletonList(transformation.getId());
+ }
+
+ /**
+ * The {@link NoOpStreamOperator} acts as a dummy sink so that the upstream can produce the
+ * intermediate dataset to be cached.
+ *
+ * @param <T> The output type of the operator, which is the type of the cached intermediate
+ * dataset as well.
+ */
+ public static class NoOpStreamOperator<T> extends AbstractStreamOperator<T>
+ implements OneInputStreamOperator<T, T> {
+ private static final long serialVersionUID = 4517845269225218313L;
+
+ @Override
+ public void processElement(StreamRecord<T> element) throws Exception {
+ // do nothing
+ }
+ }
+
+ /**
+     * The {@link IdentityStreamOperator} acts as a dummy source to consume the cached
+     * intermediate dataset.
+     *
+     * @param <T> The output type of the operator, which is the type of the cached intermediate
+     *     dataset as well.
+ */
+ public static class IdentityStreamOperator<T> extends AbstractStreamOperator<T>
+ implements OneInputStreamOperator<T, T> {
+ private static final long serialVersionUID = 4517845269225218313L;
+
+ @Override
+ public void processElement(StreamRecord<T> element) throws Exception {
+ output.collect(element);
+ }
+ }
+}
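To make the two translation paths concrete, the resulting topologies look roughly like this (illustrative summary derived from the code above and the tests below):

    First run (isCached == false), batch mode:
        upstream operator --(ForwardPartitioner, persistent blocking
        partition tagged with the dataset id)--> CacheWrite (NoOpStreamOperator)

    Later runs (isCached == true):
        CacheRead (IdentityStreamOperator, added as a legacy source that
        consumes the persisted cluster dataset) --> downstream operators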
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
index 7580642d420..3e52c036f7a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
@@ -19,6 +19,7 @@
package org.apache.flink.streaming.api.graph;
import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.operators.ResourceSpec;
import org.apache.flink.api.common.operators.SlotSharingGroup;
@@ -33,6 +34,7 @@ import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.jobgraph.SavepointConfigOptions;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.streaming.api.datastream.BroadcastStream;
+import org.apache.flink.streaming.api.datastream.CachedDataStream;
import org.apache.flink.streaming.api.datastream.ConnectedStreams;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.IterativeStream;
@@ -55,11 +57,13 @@ import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.transformations.CacheTransformation;
import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.StreamExchangeMode;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.GlobalPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
import org.apache.flink.streaming.runtime.partitioner.ShufflePartitioner;
@@ -67,9 +71,12 @@ import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.translators.CacheTransformationTranslator;
import org.apache.flink.streaming.util.NoOpIntMap;
import org.apache.flink.streaming.util.TestExpandingSink;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
import org.apache.flink.util.TestLogger;
import org.assertj.core.api.Assertions;
@@ -77,15 +84,17 @@ import org.hamcrest.Description;
import org.hamcrest.FeatureMatcher;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -93,9 +102,6 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.iterableWithSize;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
/**
* Tests for {@link StreamGraphGenerator}. This only tests correct translation of split/select,
@@ -140,19 +146,19 @@ public class StreamGraphGeneratorTest extends TestLogger {
for (StreamNode node : sg.getStreamNodes()) {
switch (node.getOperatorName()) {
case "A":
- assertEquals(77L, node.getBufferTimeout());
+ Assertions.assertThat(77L).isEqualTo(node.getBufferTimeout());
break;
case "B":
- assertEquals(0L, node.getBufferTimeout());
+ Assertions.assertThat(node.getBufferTimeout()).isEqualTo(0L);
break;
case "C":
- assertEquals(12L, node.getBufferTimeout());
+ Assertions.assertThat(node.getBufferTimeout()).isEqualTo(12L);
break;
case "D":
- assertEquals(77L, node.getBufferTimeout());
+ Assertions.assertThat(node.getBufferTimeout()).isEqualTo(77L);
break;
default:
- assertTrue(node.getOperator() instanceof StreamSource);
+ Assertions.assertThat(node.getOperator()).isInstanceOf(StreamSource.class);
}
}
}
@@ -198,29 +204,47 @@ public class StreamGraphGeneratorTest extends TestLogger {
StreamGraph graph = env.getStreamGraph();
// rebalanceMap
- assertTrue(
- graph.getStreamNode(rebalanceMap.getId()).getInEdges().get(0).getPartitioner()
- instanceof RebalancePartitioner);
+ Assertions.assertThat(
+ graph.getStreamNode(rebalanceMap.getId())
+ .getInEdges()
+ .get(0)
+ .getPartitioner())
+ .isInstanceOf(RebalancePartitioner.class);
// verify that only last partitioning takes precedence
- assertTrue(
- graph.getStreamNode(broadcastMap.getId()).getInEdges().get(0).getPartitioner()
- instanceof BroadcastPartitioner);
- assertEquals(
- rebalanceMap.getId(),
- graph.getSourceVertex(graph.getStreamNode(broadcastMap.getId()).getInEdges().get(0))
- .getId());
+ Assertions.assertThat(
+ graph.getStreamNode(broadcastMap.getId())
+ .getInEdges()
+ .get(0)
+ .getPartitioner())
+ .isInstanceOf(BroadcastPartitioner.class);
+ Assertions.assertThat(
+ graph.getSourceVertex(
+ graph.getStreamNode(broadcastMap.getId())
+ .getInEdges()
+ .get(0))
+ .getId())
+ .isEqualTo(rebalanceMap.getId());
// verify that partitioning in unions is preserved
- assertTrue(
- graph.getStreamNode(broadcastOperator.getId()).getOutEdges().get(0).getPartitioner()
- instanceof BroadcastPartitioner);
- assertTrue(
- graph.getStreamNode(globalOperator.getId()).getOutEdges().get(0).getPartitioner()
- instanceof GlobalPartitioner);
- assertTrue(
- graph.getStreamNode(shuffleOperator.getId()).getOutEdges().get(0).getPartitioner()
- instanceof ShufflePartitioner);
+ Assertions.assertThat(
+ graph.getStreamNode(broadcastOperator.getId())
+ .getOutEdges()
+ .get(0)
+ .getPartitioner())
+ .isInstanceOf(BroadcastPartitioner.class);
+ Assertions.assertThat(
+ graph.getStreamNode(globalOperator.getId())
+ .getOutEdges()
+ .get(0)
+ .getPartitioner())
+ .isInstanceOf(GlobalPartitioner.class);
+ Assertions.assertThat(
+ graph.getStreamNode(shuffleOperator.getId())
+ .getOutEdges()
+ .get(0)
+ .getPartitioner())
+ .isInstanceOf(ShufflePartitioner.class);
}
@Test
@@ -238,8 +262,8 @@ public class StreamGraphGeneratorTest extends TestLogger {
env.getStreamGraph();
- assertTrue(udfOperator instanceof AbstractUdfStreamOperator);
- assertEquals(BasicTypeInfo.INT_TYPE_INFO, function.getTypeInformation());
+ Assertions.assertThat(udfOperator).isInstanceOf(AbstractUdfStreamOperator.class);
+ Assertions.assertThat(function.getTypeInformation()).isEqualTo(BasicTypeInfo.INT_TYPE_INFO);
}
/**
@@ -265,8 +289,8 @@ public class StreamGraphGeneratorTest extends TestLogger {
env.getStreamGraph();
- assertEquals(
- BasicTypeInfo.INT_TYPE_INFO, outputTypeConfigurableOperation.getTypeInformation());
+ Assertions.assertThat(outputTypeConfigurableOperation.getTypeInformation())
+ .isEqualTo(BasicTypeInfo.INT_TYPE_INFO);
}
@Test
@@ -291,8 +315,8 @@ public class StreamGraphGeneratorTest extends TestLogger {
env.getStreamGraph();
- assertEquals(
- BasicTypeInfo.INT_TYPE_INFO, outputTypeConfigurableOperation.getTypeInformation());
+ Assertions.assertThat(outputTypeConfigurableOperation.getTypeInformation())
+ .isEqualTo(BasicTypeInfo.INT_TYPE_INFO);
}
@Test
@@ -317,15 +341,18 @@ public class StreamGraphGeneratorTest extends TestLogger {
.addInput(source3.getTransformation()));
StreamGraph streamGraph = env.getStreamGraph();
- assertEquals(4, streamGraph.getStreamNodes().size());
-
- assertEquals(1, streamGraph.getStreamEdges(source1.getId(), transform.getId()).size());
- assertEquals(1, streamGraph.getStreamEdges(source2.getId(), transform.getId()).size());
- assertEquals(1, streamGraph.getStreamEdges(source3.getId(), transform.getId()).size());
- assertEquals(1, streamGraph.getStreamEdges(source1.getId()).size());
- assertEquals(1, streamGraph.getStreamEdges(source2.getId()).size());
- assertEquals(1, streamGraph.getStreamEdges(source3.getId()).size());
- assertEquals(0, streamGraph.getStreamEdges(transform.getId()).size());
+ Assertions.assertThat(streamGraph.getStreamNodes().size()).isEqualTo(4);
+
+ Assertions.assertThat(streamGraph.getStreamEdges(source1.getId(), transform.getId()).size())
+ .isEqualTo(1);
+ Assertions.assertThat(streamGraph.getStreamEdges(source2.getId(), transform.getId()).size())
+ .isEqualTo(1);
+ Assertions.assertThat(streamGraph.getStreamEdges(source3.getId(), transform.getId()).size())
+ .isEqualTo(1);
+ Assertions.assertThat(streamGraph.getStreamEdges(source1.getId()).size()).isEqualTo(1);
+ Assertions.assertThat(streamGraph.getStreamEdges(source2.getId()).size()).isEqualTo(1);
+ Assertions.assertThat(streamGraph.getStreamEdges(source3.getId()).size()).isEqualTo(1);
+ Assertions.assertThat(streamGraph.getStreamEdges(transform.getId()).size()).isEqualTo(0);
}
@Test
@@ -359,7 +386,7 @@ public class StreamGraphGeneratorTest extends TestLogger {
DataStream<Long> map4 = map3.rescale().map(l -> l).setParallelism(1337);
StreamGraph streamGraph = env.getStreamGraph();
- assertEquals(7, streamGraph.getStreamNodes().size());
+ Assertions.assertThat(streamGraph.getStreamNodes().size()).isEqualTo(7);
// forward
assertThat(edge(streamGraph, source1, map1), supportsUnalignedCheckpoints(false));
@@ -403,7 +430,7 @@ public class StreamGraphGeneratorTest extends TestLogger {
});
StreamGraph streamGraph = env.getStreamGraph();
- assertEquals(4, streamGraph.getStreamNodes().size());
+ Assertions.assertThat(streamGraph.getStreamNodes().size()).isEqualTo(4);
// single broadcast
assertThat(edge(streamGraph, source1, map1), supportsUnalignedCheckpoints(false));
@@ -482,8 +509,9 @@ public class StreamGraphGeneratorTest extends TestLogger {
StreamNode keyedResult1Node = graph.getStreamNode(keyedResult1.getId());
StreamNode keyedResult2Node = graph.getStreamNode(keyedResult2.getId());
- assertEquals(globalMaxParallelism, keyedResult1Node.getMaxParallelism());
- assertEquals(keyedResult2MaxParallelism, keyedResult2Node.getMaxParallelism());
+ Assertions.assertThat(keyedResult1Node.getMaxParallelism()).isEqualTo(globalMaxParallelism);
+ Assertions.assertThat(keyedResult2Node.getMaxParallelism())
+ .isEqualTo(keyedResult2MaxParallelism);
}
/**
@@ -528,8 +556,8 @@ public class StreamGraphGeneratorTest extends TestLogger {
StreamNode keyedResult3Node = graph.getStreamNode(keyedResult3.getId());
StreamNode keyedResult4Node = graph.getStreamNode(keyedResult4.getId());
- assertEquals(maxParallelism, keyedResult3Node.getMaxParallelism());
- assertEquals(maxParallelism, keyedResult4Node.getMaxParallelism());
+ Assertions.assertThat(keyedResult3Node.getMaxParallelism()).isEqualTo(maxParallelism);
+ Assertions.assertThat(keyedResult4Node.getMaxParallelism()).isEqualTo(maxParallelism);
}
/** Tests that the max parallelism is properly set for connected streams. */
@@ -602,15 +630,14 @@ public class StreamGraphGeneratorTest extends TestLogger {
StreamGraph streamGraph = env.getStreamGraph();
for (Tuple2<StreamNode, StreamNode> iterationPair :
streamGraph.getIterationSourceSinkPairs()) {
- assertNotNull(iterationPair.f0.getCoLocationGroup());
- assertEquals(
- iterationPair.f0.getCoLocationGroup(), iterationPair.f1.getCoLocationGroup());
+ Assertions.assertThat(iterationPair.f0.getCoLocationGroup()).isNotNull();
+ Assertions.assertThat(iterationPair.f1.getCoLocationGroup())
+ .isEqualTo(iterationPair.f0.getCoLocationGroup());
- assertEquals(
- StreamGraphGenerator.DEFAULT_SLOT_SHARING_GROUP,
- iterationPair.f0.getSlotSharingGroup());
- assertEquals(
- iterationPair.f0.getSlotSharingGroup(), iterationPair.f1.getSlotSharingGroup());
+ Assertions.assertThat(iterationPair.f0.getSlotSharingGroup())
+ .isEqualTo(StreamGraphGenerator.DEFAULT_SLOT_SHARING_GROUP);
+ Assertions.assertThat(iterationPair.f1.getSlotSharingGroup())
+ .isEqualTo(iterationPair.f0.getSlotSharingGroup());
final ResourceSpec sourceMinResources = iterationPair.f0.getMinResources();
final ResourceSpec sinkMinResources = iterationPair.f1.getMinResources();
@@ -642,9 +669,8 @@ public class StreamGraphGeneratorTest extends TestLogger {
Collection<StreamNode> streamNodes = streamGraph.getStreamNodes();
for (StreamNode streamNode : streamNodes) {
- assertEquals(
- StreamGraphGenerator.DEFAULT_SLOT_SHARING_GROUP,
- streamNode.getSlotSharingGroup());
+ Assertions.assertThat(streamNode.getSlotSharingGroup())
+ .isEqualTo(StreamGraphGenerator.DEFAULT_SLOT_SHARING_GROUP);
}
}
@@ -796,7 +822,7 @@ public class StreamGraphGeneratorTest extends TestLogger {
is(ResourceProfile.fromResources(3, 300)));
}
- @Test(expected = IllegalArgumentException.class)
+ @Test
public void testConflictSlotSharingGroup() {
final SlotSharingGroup ssg =
SlotSharingGroup.newBuilder("ssg").setCpuCores(1).setTaskHeapMemoryMB(100).build();
@@ -810,7 +836,8 @@ public class StreamGraphGeneratorTest extends TestLogger {
.addSink(new DiscardingSink<>())
.slotSharingGroup(ssgConflict);
- env.getStreamGraph();
+ Assertions.assertThatThrownBy(env::getStreamGraph)
+ .isInstanceOf(IllegalArgumentException.class);
}
@Test
@@ -875,8 +902,132 @@ public class StreamGraphGeneratorTest extends TestLogger {
.forEach(
node -> {
if (!node.getOperatorName().startsWith("Source")) {
- assertEquals(2, node.getParallelism());
+ Assertions.assertThat(node.getParallelism()).isEqualTo(2);
+ }
+ });
+ }
+
+ @Test
+ public void testCacheInStreamModeThrowsException() {
+ final TestingStreamExecutionEnvironment env = new TestingStreamExecutionEnvironment();
+ env.setRuntimeMode(RuntimeExecutionMode.STREAMING);
+
+ DataStream<Integer> source = env.fromElements(1, 2, 3);
+ final int upstreamParallelism = 3;
+ CachedDataStream<Integer> cachedStream =
+ source.keyBy(i -> i)
+ .reduce(Integer::sum)
+ .setParallelism(upstreamParallelism)
+ .cache();
+ cachedStream.print();
+
+ Assertions.assertThatThrownBy(env::getStreamGraph).isInstanceOf(RuntimeException.class);
+ }
+
+ @Test
+ public void testCacheTransformation() {
+ final TestingStreamExecutionEnvironment env = new TestingStreamExecutionEnvironment();
+ env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+
+ DataStream<Integer> source = env.fromElements(1, 2, 3);
+ final int upstreamParallelism = 3;
+ CachedDataStream<Integer> cachedStream =
+ source.keyBy(i -> i)
+ .reduce(Integer::sum)
+ .setParallelism(upstreamParallelism)
+ .cache();
+ Assertions.assertThat(cachedStream.getTransformation())
+ .isInstanceOf(CacheTransformation.class);
+ CacheTransformation<Integer> cacheTransformation =
+ (CacheTransformation<Integer>) cachedStream.getTransformation();
+
+ cachedStream.print();
+ final StreamGraph streamGraph = env.getStreamGraph();
+
+ verifyCacheProduceNode(upstreamParallelism, cacheTransformation, streamGraph, null);
+
+ env.addCompletedClusterDatasetIds(cacheTransformation.getDatasetId());
+ cachedStream.print();
+
+ verifyCacheConsumeNode(env, upstreamParallelism, cacheTransformation);
+ }
+
+ @Test
+ public void testCacheSideOutput() {
+ final TestingStreamExecutionEnvironment env = new TestingStreamExecutionEnvironment();
+ env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+
+ final int upstreamParallelism = 2;
+ SingleOutputStreamOperator<Integer> stream =
+ env.fromElements(1, 2, 3).map(i -> i).setParallelism(upstreamParallelism);
+ final DataStream<Integer> sideOutputCache =
+ stream.getSideOutput(new OutputTag<Integer>("1") {}).cache();
+ Assertions.assertThat(sideOutputCache.getTransformation())
+ .isInstanceOf(CacheTransformation.class);
+ final CacheTransformation<Integer> cacheTransformation =
+ (CacheTransformation<Integer>) sideOutputCache.getTransformation();
+ sideOutputCache.print();
+
+ final StreamGraph streamGraph = env.getStreamGraph();
+ verifyCacheProduceNode(upstreamParallelism, cacheTransformation, streamGraph, "1");
+
+ env.addCompletedClusterDatasetIds(cacheTransformation.getDatasetId());
+ sideOutputCache.print();
+
+ verifyCacheConsumeNode(env, upstreamParallelism, cacheTransformation);
+ }
+
+ private void verifyCacheProduceNode(
+ int upstreamParallelism,
+ CacheTransformation<Integer> cacheTransformation,
+ StreamGraph streamGraph,
+ String expectedTagId) {
+ Assertions.assertThat(streamGraph.getStreamNodes())
+ .anyMatch(
+ node -> {
+ if (!CacheTransformationTranslator.CACHE_PRODUCER_OPERATOR_NAME.equals(
+ node.getOperatorName())) {
+ return false;
+ }
+
+ Assertions.assertThat(node.getParallelism())
+ .isEqualTo(upstreamParallelism);
+ Assertions.assertThat(node.getInEdges().size()).isEqualTo(1);
+ final StreamEdge inEdge = node.getInEdges().get(0);
+ Assertions.assertThat(inEdge.getPartitioner())
+ .isInstanceOf(ForwardPartitioner.class);
+ if (expectedTagId != null) {
+ Assertions.assertThat(inEdge.getOutputTag().getId())
+ .isEqualTo(expectedTagId);
+ }
+
+ Assertions.assertThat(inEdge.getIntermediateDatasetIdToProduce())
+ .isNotNull();
+ Assertions.assertThat(
+ new AbstractID(
+ inEdge.getIntermediateDatasetIdToProduce()))
+ .isEqualTo(cacheTransformation.getDatasetId());
+ return true;
+ });
+ }
+
+ private void verifyCacheConsumeNode(
+ StreamExecutionEnvironment env,
+ int upstreamParallelism,
+ CacheTransformation<Integer> cacheTransformation) {
+ Assertions.assertThat(env.getStreamGraph().getStreamNodes())
+ .anyMatch(
+ node -> {
+ if (!CacheTransformationTranslator.CACHE_CONSUMER_OPERATOR_NAME.equals(
+ node.getOperatorName())) {
+ return false;
}
+
+ Assertions.assertThat(node.getParallelism())
+ .isEqualTo(upstreamParallelism);
+ Assertions.assertThat(new AbstractID(node.getConsumeClusterDatasetId()))
+ .isEqualTo(cacheTransformation.getDatasetId());
+ return true;
});
}
@@ -1069,4 +1220,17 @@ public class StreamGraphGeneratorTest extends TestLogger {
throw new UnsupportedOperationException();
}
}
+
+ private static class TestingStreamExecutionEnvironment extends StreamExecutionEnvironment {
+ Set<AbstractID> completedClusterDatasetIds = new HashSet<>();
+
+ public void addCompletedClusterDatasetIds(AbstractID id) {
+ completedClusterDatasetIds.add(id);
+ }
+
+ @Override
+ public Set<AbstractID> listCompletedClusterDatasets() {
+ return new HashSet<>(completedClusterDatasetIds);
+ }
+ }
}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
index 0d7b8513ee2..b38046bd7d2 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
@@ -49,6 +49,7 @@ import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.InputOutputFormatContainer;
import org.apache.flink.runtime.jobgraph.InputOutputFormatVertex;
+import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobType;
import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -60,6 +61,7 @@ import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.datastream.CachedDataStream;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSink;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
@@ -82,6 +84,7 @@ import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.api.operators.YieldingOperatorFactory;
+import org.apache.flink.streaming.api.transformations.CacheTransformation;
import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
@@ -92,6 +95,7 @@ import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTask;
import org.apache.flink.streaming.runtime.tasks.SourceOperatorStreamTask;
import org.apache.flink.streaming.util.TestAnyModeReadingStreamOperator;
+import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import org.apache.flink.util.SerializedValue;
import org.apache.flink.util.TestLogger;
@@ -99,10 +103,10 @@ import org.apache.flink.util.TestLogger;
import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables;
import org.assertj.core.api.Assertions;
+import org.assertj.core.data.Offset;
import org.hamcrest.FeatureMatcher;
import org.hamcrest.Matcher;
-import org.junit.Assert;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import java.lang.reflect.Method;
import java.util.ArrayList;
@@ -110,23 +114,19 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
+import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE;
import static org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator.areOperatorsChainable;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
/** Tests for {@link StreamingJobGraphGenerator}. */
@SuppressWarnings("serial")
@@ -176,19 +176,17 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobGraph jobGraph = streamGraph.getJobGraph();
List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
- assertEquals(2, jobGraph.getNumberOfVertices());
- assertEquals(1, verticesSorted.get(0).getParallelism());
- assertEquals(1, verticesSorted.get(1).getParallelism());
+ Assertions.assertThat(jobGraph.getNumberOfVertices()).isEqualTo(2);
+ Assertions.assertThat(verticesSorted.get(0).getParallelism()).isEqualTo(1);
+ Assertions.assertThat(verticesSorted.get(1).getParallelism()).isEqualTo(1);
JobVertex sourceVertex = verticesSorted.get(0);
JobVertex mapSinkVertex = verticesSorted.get(1);
- assertEquals(
- ResultPartitionType.PIPELINED_BOUNDED,
- sourceVertex.getProducedDataSets().get(0).getResultType());
- assertEquals(
- ResultPartitionType.PIPELINED_BOUNDED,
- mapSinkVertex.getInputs().get(0).getSource().getResultType());
+ Assertions.assertThat(sourceVertex.getProducedDataSets().get(0).getResultType())
+ .isEqualTo(ResultPartitionType.PIPELINED_BOUNDED);
+ Assertions.assertThat(mapSinkVertex.getInputs().get(0).getSource().getResultType())
+ .isEqualTo(ResultPartitionType.PIPELINED_BOUNDED);
}
/**
@@ -200,23 +198,28 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.fromElements(0).print();
StreamGraph streamGraph = env.getStreamGraph();
- assertFalse(
- "Checkpointing enabled",
- streamGraph.getCheckpointConfig().isCheckpointingEnabled());
+ Assertions.assertThat(streamGraph.getCheckpointConfig().isCheckpointingEnabled())
+ .withFailMessage("Checkpointing enabled")
+ .isFalse();
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph);
JobCheckpointingSettings snapshottingSettings = jobGraph.getCheckpointingSettings();
- assertEquals(
- Long.MAX_VALUE,
- snapshottingSettings
- .getCheckpointCoordinatorConfiguration()
- .getCheckpointInterval());
- assertFalse(snapshottingSettings.getCheckpointCoordinatorConfiguration().isExactlyOnce());
+ Assertions.assertThat(
+ snapshottingSettings
+ .getCheckpointCoordinatorConfiguration()
+ .getCheckpointInterval())
+ .isEqualTo(Long.MAX_VALUE);
+ Assertions.assertThat(
+ snapshottingSettings
+ .getCheckpointCoordinatorConfiguration()
+ .isExactlyOnce())
+ .isFalse();
List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
StreamConfig streamConfig = new StreamConfig(verticesSorted.get(0).getConfiguration());
- assertEquals(CheckpointingMode.AT_LEAST_ONCE, streamConfig.getCheckpointMode());
+ Assertions.assertThat(streamConfig.getCheckpointMode())
+ .isEqualTo(CheckpointingMode.AT_LEAST_ONCE);
}
@Test
@@ -224,17 +227,18 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.fromElements(0).print();
StreamGraph streamGraph = env.getStreamGraph();
- assertFalse(
- "Checkpointing enabled",
- streamGraph.getCheckpointConfig().isCheckpointingEnabled());
+ Assertions.assertThat(streamGraph.getCheckpointConfig().isCheckpointingEnabled())
+ .withFailMessage("Checkpointing enabled")
+ .isFalse();
env.getCheckpointConfig().enableUnalignedCheckpoints(true);
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph);
List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
StreamConfig streamConfig = new StreamConfig(verticesSorted.get(0).getConfiguration());
- assertEquals(CheckpointingMode.AT_LEAST_ONCE, streamConfig.getCheckpointMode());
- assertFalse(streamConfig.isUnalignedCheckpointsEnabled());
+ Assertions.assertThat(streamConfig.getCheckpointMode())
+ .isEqualTo(CheckpointingMode.AT_LEAST_ONCE);
+ Assertions.assertThat(streamConfig.isUnalignedCheckpointsEnabled()).isFalse();
}
@Test
@@ -249,8 +253,9 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
StreamConfig streamConfig = new StreamConfig(verticesSorted.get(0).getConfiguration());
- assertEquals(CheckpointingMode.AT_LEAST_ONCE, streamConfig.getCheckpointMode());
- assertFalse(streamConfig.isUnalignedCheckpointsEnabled());
+ Assertions.assertThat(streamConfig.getCheckpointMode())
+ .isEqualTo(CheckpointingMode.AT_LEAST_ONCE);
+ Assertions.assertThat(streamConfig.isUnalignedCheckpointsEnabled()).isFalse();
}
@Test
@@ -292,12 +297,10 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobVertex sourceVertex = verticesSorted.get(0);
JobVertex mapPrintVertex = verticesSorted.get(1);
- assertEquals(
- ResultPartitionType.PIPELINED_BOUNDED,
- sourceVertex.getProducedDataSets().get(0).getResultType());
- assertEquals(
- ResultPartitionType.PIPELINED_BOUNDED,
- mapPrintVertex.getInputs().get(0).getSource().getResultType());
+ Assertions.assertThat(sourceVertex.getProducedDataSets().get(0).getResultType())
+ .isEqualTo(ResultPartitionType.PIPELINED_BOUNDED);
+ Assertions.assertThat(mapPrintVertex.getInputs().get(0).getSource().getResultType())
+ .isEqualTo(ResultPartitionType.PIPELINED_BOUNDED);
StreamConfig sourceConfig = new StreamConfig(sourceVertex.getConfiguration());
StreamConfig mapConfig = new StreamConfig(mapPrintVertex.getConfiguration());
@@ -305,14 +308,14 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
mapConfig.getTransitiveChainedTaskConfigs(getClass().getClassLoader());
StreamConfig printConfig = chainedConfigs.values().iterator().next();
- assertTrue(sourceConfig.isChainStart());
- assertTrue(sourceConfig.isChainEnd());
+ Assertions.assertThat(sourceConfig.isChainStart()).isTrue();
+ Assertions.assertThat(sourceConfig.isChainEnd()).isTrue();
- assertTrue(mapConfig.isChainStart());
- assertFalse(mapConfig.isChainEnd());
+ Assertions.assertThat(mapConfig.isChainStart()).isTrue();
+ Assertions.assertThat(mapConfig.isChainEnd()).isFalse();
- assertFalse(printConfig.isChainStart());
- assertTrue(printConfig.isChainEnd());
+ Assertions.assertThat(printConfig.isChainStart()).isFalse();
+ Assertions.assertThat(printConfig.isChainEnd()).isTrue();
}
@Test
@@ -336,7 +339,8 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
- assertEquals(2, jobGraph.getVerticesAsArray()[0].getOperatorCoordinators().size());
+ Assertions.assertThat(jobGraph.getVerticesAsArray()[0].getOperatorCoordinators().size())
+ .isEqualTo(2);
}
/**
@@ -418,11 +422,14 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
jobGraph.getVerticesSortedTopologicallyFromSources().get(0);
JobVertex reduceSinkVertex = jobGraph.getVerticesSortedTopologicallyFromSources().get(1);
- assertTrue(
- sourceMapFilterVertex
- .getMinResources()
- .equals(resource3.merge(resource2).merge(resource1)));
- assertTrue(reduceSinkVertex.getPreferredResources().equals(resource4.merge(resource5)));
+ Assertions.assertThat(sourceMapFilterVertex.getMinResources())
+ .isEqualTo(resource3.merge(resource2).merge(resource1));
+ Assertions.assertThat(reduceSinkVertex.getPreferredResources())
+ .isEqualTo(resource4.merge(resource5));
}
/**
@@ -498,15 +505,19 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
for (JobVertex jobVertex : jobGraph.getVertices()) {
if (jobVertex.getName().contains("test_source")) {
- assertTrue(jobVertex.getMinResources().equals(resource1));
+ Assertions.assertThat(jobVertex.getMinResources()).isEqualTo(resource1);
} else if (jobVertex.getName().contains("Iteration_Source")) {
- assertTrue(jobVertex.getPreferredResources().equals(resource2));
+ Assertions.assertThat(jobVertex.getPreferredResources()).isEqualTo(resource2);
} else if (jobVertex.getName().contains("test_flatMap")) {
- assertTrue(jobVertex.getMinResources().equals(resource3.merge(resource4)));
+ Assertions.assertThat(jobVertex.getMinResources())
+ .isEqualTo(resource3.merge(resource4));
} else if (jobVertex.getName().contains("Iteration_Tail")) {
- assertTrue(jobVertex.getPreferredResources().equals(ResourceSpec.DEFAULT));
+ Assertions.assertThat(jobVertex.getPreferredResources())
+ .isEqualTo(ResourceSpec.DEFAULT);
} else if (jobVertex.getName().contains("test_sink")) {
- assertTrue(jobVertex.getMinResources().equals(resource5));
+ Assertions.assertThat(jobVertex.getMinResources()).isEqualTo(resource5);
}
}
}
@@ -529,10 +540,10 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
StreamGraph streamGraph = env.getStreamGraph();
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph);
- assertEquals(1, jobGraph.getNumberOfVertices());
+ Assertions.assertThat(jobGraph.getNumberOfVertices()).isEqualTo(1);
JobVertex jobVertex = jobGraph.getVertices().iterator().next();
- assertTrue(jobVertex instanceof InputOutputFormatVertex);
+ Assertions.assertThat(jobVertex).isInstanceOf(InputOutputFormatVertex.class);
InputOutputFormatContainer formatContainer =
new InputOutputFormatContainer(
@@ -542,8 +553,8 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
formatContainer.getInputFormats();
Map<OperatorID, UserCodeWrapper<? extends OutputFormat<?>>> outputFormats =
formatContainer.getOutputFormats();
- assertEquals(1, inputFormats.size());
- assertEquals(2, outputFormats.size());
+ Assertions.assertThat(inputFormats.size()).isEqualTo(1);
+ Assertions.assertThat(outputFormats.size()).isEqualTo(2);
Map<String, OperatorID> nameToOperatorIds = new HashMap<>();
StreamConfig headConfig = new StreamConfig(jobVertex.getConfiguration());
@@ -558,15 +569,15 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
InputFormat<?, ?> sourceFormat =
inputFormats.get(nameToOperatorIds.get("Source: source")).getUserCodeObject();
- assertTrue(sourceFormat instanceof TypeSerializerInputFormat);
+ Assertions.assertThat(sourceFormat).isInstanceOf(TypeSerializerInputFormat.class);
OutputFormat<?> sinkFormat1 =
outputFormats.get(nameToOperatorIds.get("Sink: sink1")).getUserCodeObject();
- assertTrue(sinkFormat1 instanceof DiscardingOutputFormat);
+ Assertions.assertThat(sinkFormat1).isInstanceOf(DiscardingOutputFormat.class);
OutputFormat<?> sinkFormat2 =
outputFormats.get(nameToOperatorIds.get("Sink: sink2")).getUserCodeObject();
- assertTrue(sinkFormat2 instanceof DiscardingOutputFormat);
+ Assertions.assertThat(sinkFormat2).isInstanceOf(DiscardingOutputFormat.class);
}
@Test
@@ -582,20 +593,21 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
StreamGraph streamGraph = env.getStreamGraph();
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph);
// There should be only one job vertex.
- assertEquals(1, jobGraph.getNumberOfVertices());
+ Assertions.assertThat(jobGraph.getNumberOfVertices()).isEqualTo(1);
JobVertex jobVertex = jobGraph.getVerticesAsArray()[0];
List<SerializedValue<OperatorCoordinator.Provider>> coordinatorProviders =
jobVertex.getOperatorCoordinators();
// There should be only one coordinator provider.
- assertEquals(1, coordinatorProviders.size());
+ Assertions.assertThat(coordinatorProviders.size()).isEqualTo(1);
// The invokable class should be SourceOperatorStreamTask.
final ClassLoader classLoader = getClass().getClassLoader();
- assertEquals(SourceOperatorStreamTask.class, jobVertex.getInvokableClass(classLoader));
+ Assertions.assertThat(jobVertex.getInvokableClass(classLoader))
+ .isEqualTo(SourceOperatorStreamTask.class);
StreamOperatorFactory operatorFactory =
new StreamConfig(jobVertex.getConfiguration())
.getStreamOperatorFactory(classLoader);
- assertTrue(operatorFactory instanceof SourceOperatorFactory);
+ Assertions.assertThat(operatorFactory).isInstanceOf(SourceOperatorFactory.class);
}
/** Test setting exchange mode to {@link StreamExchangeMode#PIPELINED}. */
@@ -627,15 +639,14 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
- assertEquals(2, verticesSorted.size());
+ Assertions.assertThat(verticesSorted.size()).isEqualTo(2);
// it can be chained with PIPELINED exchange mode
JobVertex sourceAndMapVertex = verticesSorted.get(0);
// PIPELINED exchange mode is translated into PIPELINED_BOUNDED result partition
- assertEquals(
- ResultPartitionType.PIPELINED_BOUNDED,
- sourceAndMapVertex.getProducedDataSets().get(0).getResultType());
+ Assertions.assertThat(sourceAndMapVertex.getProducedDataSets().get(0).getResultType())
+ .isEqualTo(ResultPartitionType.PIPELINED_BOUNDED);
}
/** Test setting exchange mode to {@link StreamExchangeMode#BATCH}. */
@@ -669,19 +680,17 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
- assertEquals(3, verticesSorted.size());
+ Assertions.assertThat(verticesSorted.size()).isEqualTo(3);
// it can not be chained with BATCH exchange mode
JobVertex sourceVertex = verticesSorted.get(0);
JobVertex mapVertex = verticesSorted.get(1);
// BATCH exchange mode is translated into BLOCKING result partition
- assertEquals(
- ResultPartitionType.BLOCKING,
- sourceVertex.getProducedDataSets().get(0).getResultType());
- assertEquals(
- ResultPartitionType.BLOCKING,
- mapVertex.getProducedDataSets().get(0).getResultType());
+ Assertions.assertThat(sourceVertex.getProducedDataSets().get(0).getResultType())
+ .isEqualTo(ResultPartitionType.BLOCKING);
+ Assertions.assertThat(mapVertex.getProducedDataSets().get(0).getResultType())
+ .isEqualTo(ResultPartitionType.BLOCKING);
}
/** Test setting exchange mode to {@link StreamExchangeMode#UNDEFINED}. */
@@ -713,15 +722,14 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
- assertEquals(2, verticesSorted.size());
+ Assertions.assertThat(verticesSorted.size()).isEqualTo(2);
// it can be chained with UNDEFINED exchange mode
JobVertex sourceAndMapVertex = verticesSorted.get(0);
// UNDEFINED exchange mode is translated into PIPELINED_BOUNDED result partition by default
- assertEquals(
- ResultPartitionType.PIPELINED_BOUNDED,
- sourceAndMapVertex.getProducedDataSets().get(0).getResultType());
+ Assertions.assertThat(sourceAndMapVertex.getProducedDataSets().get(0).getResultType())
+ .isEqualTo(ResultPartitionType.PIPELINED_BOUNDED);
}
/** Test setting exchange mode to {@link StreamExchangeMode#HYBRID}. */
@@ -753,7 +761,7 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
- assertEquals(2, verticesSorted.size());
+ Assertions.assertThat(verticesSorted.size()).isEqualTo(2);
// it can be chained with Hybrid exchange mode
JobVertex sourceAndMapVertex = verticesSorted.get(0);
@@ -768,7 +776,7 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.fromElements("test").addSink(new DiscardingSink<>());
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
- assertEquals(JobType.STREAMING, jobGraph.getJobType());
+ Assertions.assertThat(jobGraph.getJobType()).isEqualTo(JobType.STREAMING);
}
@Test
@@ -777,7 +785,7 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
env.setRuntimeMode(RuntimeExecutionMode.BATCH);
env.fromElements("test").addSink(new DiscardingSink<>());
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
- assertEquals(JobType.BATCH, jobGraph.getJobType());
+ Assertions.assertThat(jobGraph.getJobType()).isEqualTo(JobType.BATCH);
}
@Test
@@ -883,32 +891,40 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
SlotSharingGroup slotSharingGroup = jobGraph.getVerticesAsArray()[0].getSlotSharingGroup();
- assertNotNull(slotSharingGroup);
+ Assertions.assertThat(slotSharingGroup).isNotNull();
CoLocationGroup iterationSourceCoLocationGroup = null;
CoLocationGroup iterationSinkCoLocationGroup = null;
for (JobVertex jobVertex : jobGraph.getVertices()) {
// all vertices have same slot sharing group by default
- assertEquals(slotSharingGroup, jobVertex.getSlotSharingGroup());
+ Assertions.assertThat(jobVertex.getSlotSharingGroup()).isEqualTo(slotSharingGroup);
// all iteration vertices have same co-location group,
// others have no co-location group by default
if (jobVertex.getName().startsWith(StreamGraph.ITERATION_SOURCE_NAME_PREFIX)) {
iterationSourceCoLocationGroup = jobVertex.getCoLocationGroup();
- assertTrue(
- iterationSourceCoLocationGroup.getVertexIds().contains(jobVertex.getID()));
+ Assertions.assertThat(iterationSourceCoLocationGroup.getVertexIds())
+ .contains(jobVertex.getID());
} else if (jobVertex.getName().startsWith(StreamGraph.ITERATION_SINK_NAME_PREFIX)) {
iterationSinkCoLocationGroup = jobVertex.getCoLocationGroup();
- assertTrue(iterationSinkCoLocationGroup.getVertexIds().contains(jobVertex.getID()));
+ Assertions.assertThat(iterationSinkCoLocationGroup.getVertexIds())
+ .contains(jobVertex.getID());
} else {
- assertNull(jobVertex.getCoLocationGroup());
+ Assertions.assertThat(jobVertex.getCoLocationGroup()).isNull();
}
}
- assertNotNull(iterationSourceCoLocationGroup);
- assertNotNull(iterationSinkCoLocationGroup);
- assertEquals(iterationSourceCoLocationGroup, iterationSinkCoLocationGroup);
+ Assertions.assertThat(iterationSourceCoLocationGroup).isNotNull();
+ Assertions.assertThat(iterationSinkCoLocationGroup).isNotNull();
+ Assertions.assertThat(iterationSinkCoLocationGroup)
+ .isEqualTo(iterationSourceCoLocationGroup);
}
/** Test default job type. */
@@ -921,7 +937,7 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
Collections.emptyList(), env.getConfig(), env.getCheckpointConfig())
.generate();
JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph);
- assertEquals(JobType.STREAMING, jobGraph.getJobType());
+ Assertions.assertThat(jobGraph.getJobType()).isEqualTo(JobType.STREAMING);
}
@Test
@@ -940,8 +956,12 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
streamGraph.getStreamNodes().stream()
.sorted(Comparator.comparingInt(StreamNode::getId))
.collect(Collectors.toList());
- assertTrue(areOperatorsChainable(streamNodes.get(0), streamNodes.get(1), streamGraph));
- assertFalse(areOperatorsChainable(streamNodes.get(1), streamNodes.get(2), streamGraph));
+ Assertions.assertThat(
+ areOperatorsChainable(streamNodes.get(0), streamNodes.get(1), streamGraph))
+ .isTrue();
+ Assertions.assertThat(
+ areOperatorsChainable(streamNodes.get(1), streamNodes.get(2), streamGraph))
+ .isFalse();
}
@Test
@@ -960,8 +980,12 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
streamGraph.getStreamNodes().stream()
.sorted(Comparator.comparingInt(StreamNode::getId))
.collect(Collectors.toList());
- assertFalse(areOperatorsChainable(streamNodes.get(0), streamNodes.get(1), streamGraph));
- assertTrue(areOperatorsChainable(streamNodes.get(1), streamNodes.get(2), streamGraph));
+ Assertions.assertThat(
+ areOperatorsChainable(streamNodes.get(0), streamNodes.get(1), streamGraph))
+ .isFalse();
+ Assertions.assertThat(
+ areOperatorsChainable(streamNodes.get(1), streamNodes.get(2), streamGraph))
+ .isTrue();
}
/**
@@ -984,9 +1008,9 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
final JobGraph jobGraph = chainEnv.getStreamGraph().getJobGraph();
final List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
- Assert.assertEquals(2, vertices.size());
- assertEquals(2, vertices.get(0).getOperatorIDs().size());
- assertEquals(5, vertices.get(1).getOperatorIDs().size());
+ Assertions.assertThat(vertices.size()).isEqualTo(2);
+ Assertions.assertThat(vertices.get(0).getOperatorIDs().size()).isEqualTo(2);
+ Assertions.assertThat(vertices.get(1).getOperatorIDs().size()).isEqualTo(5);
}
/**
@@ -1007,8 +1031,8 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
final JobGraph jobGraph = chainEnv.getStreamGraph().getJobGraph();
final List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
- Assert.assertEquals(1, vertices.size());
- assertEquals(4, vertices.get(0).getOperatorIDs().size());
+ Assertions.assertThat(vertices.size()).isEqualTo(1);
+ Assertions.assertThat(vertices.get(0).getOperatorIDs().size()).isEqualTo(4);
}
@Test
@@ -1026,12 +1050,16 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
JobGraph jobGraph2 = getUnionJobGraph(env);
JobVertex jobSink2 =
Iterables.getLast(jobGraph2.getVerticesSortedTopologicallyFromSources());
- assertNotEquals("Different runs should yield different vertexes", jobSink, jobSink2);
+ Assertions.assertThat(jobSink)
+ .withFailMessage("Different runs should yield different vertexes")
+ .isNotEqualTo(jobSink2);
List<String> actualSourceOrder =
jobSink2.getInputs().stream()
.map(edge -> edge.getSource().getProducer().getName())
.collect(Collectors.toList());
- assertEquals("Union inputs reordered", expectedSourceOrder, actualSourceOrder);
+ Assertions.assertThat(actualSourceOrder)
+ .withFailMessage("Union inputs reordered")
+ .isEqualTo(expectedSourceOrder);
}
}
@@ -1050,7 +1078,7 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
return env.fromElements(index).name("source" + index).map(i -> i).name("map" + index);
}
- @Test(expected = UnsupportedOperationException.class)
+ @Test
public void testNotSupportInputSelectableOperatorIfCheckpointing() {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.enableCheckpointing(60_000L);
@@ -1064,7 +1092,9 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
new TestAnyModeReadingStreamOperator("test operator"))
.print();
- StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
+ Assertions.assertThatThrownBy(
+ () -> StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph()))
+ .isInstanceOf(UnsupportedOperationException.class);
}
@Test
@@ -1223,25 +1253,25 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
double expectedStateBackendFrac,
Configuration tmConfig) {
final double delta = 0.000001;
- assertEquals(
- expectedStateBackendFrac,
- streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
- ManagedMemoryUseCase.STATE_BACKEND,
- tmConfig,
- ClassLoader.getSystemClassLoader()),
- delta);
- assertEquals(
- expectedPythonFrac,
- streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
- ManagedMemoryUseCase.PYTHON, tmConfig, ClassLoader.getSystemClassLoader()),
- delta);
- assertEquals(
- expectedBatchFrac,
- streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
- ManagedMemoryUseCase.OPERATOR,
- tmConfig,
- ClassLoader.getSystemClassLoader()),
- delta);
+ Assertions.assertThat(
+ streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
+ ManagedMemoryUseCase.STATE_BACKEND,
+ tmConfig,
+ ClassLoader.getSystemClassLoader()))
+ .isCloseTo(expectedStateBackendFrac, Offset.offset(delta));
+ Assertions.assertThat(
+ streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
+ ManagedMemoryUseCase.PYTHON,
+ tmConfig,
+ ClassLoader.getSystemClassLoader()))
+ .isCloseTo(expectedPythonFrac, Offset.offset(delta));
+
+ Assertions.assertThat(
+ streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
+ ManagedMemoryUseCase.OPERATOR,
+ tmConfig,
+ ClassLoader.getSystemClassLoader()))
+ .isCloseTo(expectedBatchFrac, Offset.offset(delta));
}
@Test
@@ -1277,7 +1307,7 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
final JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph);
final List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
- assertEquals(4, verticesSorted.size());
+ Assertions.assertThat(verticesSorted.size()).isEqualTo(4);
final List<JobVertex> verticesMatched = getExpectedVerticesList(verticesSorted);
final JobVertex source1Vertex = verticesMatched.get(0);
@@ -1298,7 +1328,7 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
final JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph);
final List<JobVertex> verticesSorted = jobGraph.getVerticesSortedTopologicallyFromSources();
- assertEquals(4, verticesSorted.size());
+ Assertions.assertThat(verticesSorted.size()).isEqualTo(4);
final List<JobVertex> verticesMatched = getExpectedVerticesList(verticesSorted);
final JobVertex source1Vertex = verticesMatched.get(0);
@@ -1342,18 +1372,18 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
for (JobVertex jobVertex : jobGraph.getVertices()) {
numVertex += 1;
if (jobVertex.getName().contains(slotSharingGroup1)) {
- assertEquals(
- jobVertex.getSlotSharingGroup().getResourceProfile(), resourceProfile1);
+ Assertions.assertThat(jobVertex.getSlotSharingGroup().getResourceProfile())
+ .isEqualTo(resourceProfile1);
} else if (jobVertex.getName().contains(slotSharingGroup2)) {
- assertEquals(
- jobVertex.getSlotSharingGroup().getResourceProfile(), resourceProfile2);
+ Assertions.assertThat(jobVertex.getSlotSharingGroup().getResourceProfile())
+ .isEqualTo(resourceProfile2);
} else if (jobVertex
.getName()
.contains(StreamGraphGenerator.DEFAULT_SLOT_SHARING_GROUP)) {
- assertEquals(
- jobVertex.getSlotSharingGroup().getResourceProfile(), resourceProfile3);
+ Assertions.assertThat(jobVertex.getSlotSharingGroup().getResourceProfile())
+ .isEqualTo(resourceProfile3);
} else {
- fail();
+ Assertions.fail("");
}
}
assertThat(numVertex, is(3));
@@ -1376,7 +1406,8 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
int numVertex = 0;
for (JobVertex jobVertex : jobGraph.getVertices()) {
numVertex += 1;
- assertEquals(jobVertex.getSlotSharingGroup().getResourceProfile(), resourceProfile);
+ Assertions.assertThat(jobVertex.getSlotSharingGroup().getResourceProfile())
+ .isEqualTo(resourceProfile);
}
assertThat(numVertex, is(2));
}
@@ -1386,9 +1417,8 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
String[] sources = new String[] {"source-1", "source-2", "source-3"};
JobGraph graph = createGraphWithMultipleInputs(true, sources);
JobVertex head = graph.getVerticesSortedTopologicallyFromSources().iterator().next();
- Arrays.stream(sources).forEach(source -> assertTrue(head.getName().contains(source)));
- Arrays.stream(sources)
- .forEach(source -> assertTrue(head.getOperatorPrettyName().contains(source)));
+ Assertions.assertThat(sources).allMatch(source -> head.getName().contains(source));
+ Assertions.assertThat(sources)
+ .allMatch(source -> head.getOperatorPrettyName().contains(source));
}
@Test
@@ -1401,9 +1431,12 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
vertex ->
vertex.getInvokableClassName()
.equals(MultipleInputStreamTask.class.getName()));
- assertFalse(head.getName(), head.getName().contains("source-1"));
- assertFalse(
- head.getOperatorPrettyName(), head.getOperatorPrettyName().contains("source-1"));
+ Assertions.assertThat(head.getName().contains("source-1"))
+ .withFailMessage(head.getName())
+ .isFalse();
+ Assertions.assertThat(head.getOperatorPrettyName().contains("source-1"))
+ .withFailMessage(head.getOperatorPrettyName())
+ .isFalse();
}
public JobGraph createGraphWithMultipleInputs(boolean chain, String... inputNames) {
@@ -1433,16 +1466,16 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
JobGraph job = createJobGraphWithDescription(env, "test source");
JobVertex[] allVertices = job.getVerticesAsArray();
- assertEquals(1, allVertices.length);
- assertEquals(
- "test source\n"
- + ":- x + 1\n"
- + ": :- first print of map1\n"
- + ": +- second print of map1\n"
- + "+- x + 2\n"
- + " :- first print of map2\n"
- + " +- second print of map2\n",
- allVertices[0].getOperatorPrettyName());
+ Assertions.assertThat(allVertices.length).isEqualTo(1);
+ Assertions.assertThat(allVertices[0].getOperatorPrettyName())
+ .isEqualTo(
+ "test source\n"
+ + ":- x + 1\n"
+ + ": :- first print of map1\n"
+ + ": +- second print of map1\n"
+ + "+- x + 2\n"
+ + " :- first print of map2\n"
+ + " +- second print of map2\n");
}
@Test
@@ -1450,16 +1483,16 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
JobGraph job = createJobGraphWithDescription(env, "test source 1", "test source 2");
JobVertex[] allVertices = job.getVerticesAsArray();
- assertEquals(1, allVertices.length);
- assertEquals(
- "operator chained with source [test source 1, test source 2]\n"
- + ":- x + 1\n"
- + ": :- first print of map1\n"
- + ": +- second print of map1\n"
- + "+- x + 2\n"
- + " :- first print of map2\n"
- + " +- second print of map2\n",
- allVertices[0].getOperatorPrettyName());
+ Assertions.assertThat(allVertices.length).isEqualTo(1);
+ Assertions.assertThat(allVertices[0].getOperatorPrettyName())
+ .isEqualTo(
+ "operator chained with source [test source 1, test source 2]\n"
+ + ":- x + 1\n"
+ + ": :- first print of map1\n"
+ + ": +- second print of map1\n"
+ + "+- x + 2\n"
+ + " :- first print of map2\n"
+ + " +- second print of map2\n");
}
@Test
@@ -1472,10 +1505,10 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
StreamExecutionEnvironment.getExecutionEnvironment(config);
JobGraph job = createJobGraphWithDescription(env, "test source");
JobVertex[] allVertices = job.getVerticesAsArray();
- assertEquals(1, allVertices.length);
- assertEquals(
- "test source -> (x + 1 -> (first print of map1 , second print of map1) , x + 2 -> (first print of map2 , second print of map2))",
- allVertices[0].getOperatorPrettyName());
+ Assertions.assertThat(allVertices.length).isEqualTo(1);
+ Assertions.assertThat(allVertices[0].getOperatorPrettyName())
+ .isEqualTo(
+ "test source -> (x + 1 -> (first print of map1 , second print of map1) , x + 2 -> (first print of map2 , second print of map2))");
}
@Test
@@ -1488,21 +1521,21 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
StreamExecutionEnvironment.getExecutionEnvironment(config);
JobGraph job = createJobGraphWithDescription(env, "test source 1", "test source 2");
JobVertex[] allVertices = job.getVerticesAsArray();
- assertEquals(1, allVertices.length);
- assertEquals(
- "operator chained with source [test source 1, test source 2] -> (x + 1 -> (first print of map1 , second print of map1) , x + 2 -> (first print of map2 , second print of map2))",
- allVertices[0].getOperatorPrettyName());
+ Assertions.assertThat(allVertices.length).isEqualTo(1);
+ Assertions.assertThat(allVertices[0].getOperatorPrettyName())
+ .isEqualTo(
+ "operator chained with source [test source 1, test source 2] -> (x + 1 -> (first print of map1 , second print of map1) , x + 2 -> (first print of map2 , second print of map2))");
}
@Test
public void testNamingWithoutIndex() {
JobGraph job = createStreamGraphForSlotSharingTest(new Configuration()).getJobGraph();
List<JobVertex> allVertices = job.getVerticesSortedTopologicallyFromSources();
- assertEquals(4, allVertices.size());
- assertEquals("Source: source1", allVertices.get(0).getName());
- assertEquals("Source: source2", allVertices.get(1).getName());
- assertEquals("map1", allVertices.get(2).getName());
- assertEquals("map2", allVertices.get(3).getName());
+ Assertions.assertThat(allVertices.size()).isEqualTo(4);
+ Assertions.assertThat(allVertices.get(0).getName()).isEqualTo("Source: source1");
+ Assertions.assertThat(allVertices.get(1).getName()).isEqualTo("Source: source2");
+ Assertions.assertThat(allVertices.get(2).getName()).isEqualTo("map1");
+ Assertions.assertThat(allVertices.get(3).getName()).isEqualTo("map2");
}
@Test
@@ -1511,11 +1544,66 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
config.setBoolean(PipelineOptions.VERTEX_NAME_INCLUDE_INDEX_PREFIX, true);
JobGraph job = createStreamGraphForSlotSharingTest(config).getJobGraph();
List<JobVertex> allVertices = job.getVerticesSortedTopologicallyFromSources();
- assertEquals(4, allVertices.size());
- assertEquals("[vertex-0]Source: source1", allVertices.get(0).getName());
- assertEquals("[vertex-1]Source: source2", allVertices.get(1).getName());
- assertEquals("[vertex-2]map1", allVertices.get(2).getName());
- assertEquals("[vertex-3]map2", allVertices.get(3).getName());
+ Assertions.assertThat(allVertices.size()).isEqualTo(4);
+ Assertions.assertThat(allVertices.get(0).getName()).isEqualTo("[vertex-0]Source: source1");
+ Assertions.assertThat(allVertices.get(1).getName()).isEqualTo("[vertex-1]Source: source2");
+ Assertions.assertThat(allVertices.get(2).getName()).isEqualTo("[vertex-2]map1");
+ Assertions.assertThat(allVertices.get(3).getName()).isEqualTo("[vertex-3]map2");
+ }
+
+ @Test
+ public void testCacheJobGraph() throws Throwable {
+ final TestingStreamExecutionEnvironment env = new TestingStreamExecutionEnvironment();
+ env.setParallelism(2);
+ env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+
+ DataStream<Integer> source = env.fromElements(1, 2, 3).name("source");
+ CachedDataStream<Integer> cachedStream =
+ source.map(i -> i + 1).name("map-1").map(i -> i + 1).name("map-2").cache();
+ Assertions.assertThat(cachedStream.getTransformation())
+ .isInstanceOf(CacheTransformation.class);
+ CacheTransformation<Integer> cacheTransformation =
+ (CacheTransformation<Integer>) cachedStream.getTransformation();
+
+ cachedStream.print().name("print");
+
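+ // First generation: the cluster dataset is not yet available, so the generator is
+ // expected to add a CacheWrite vertex that persists the intermediate result.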
+ JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+ List<JobVertex> allVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+ Assertions.assertThat(allVertices.size()).isEqualTo(3);
+
+ final JobVertex cacheWriteVertex =
+ allVertices.stream()
+ .filter(jobVertex -> "CacheWrite".equals(jobVertex.getName()))
+ .findFirst()
+ .orElseThrow(
+ (Supplier<Throwable>)
+ () ->
+ new RuntimeException(
+ "CacheWrite job vertex not found"));
+
+ final List<JobEdge> inputs = cacheWriteVertex.getInputs();
+ Assertions.assertThat(inputs.size()).isEqualTo(1);
+ Assertions.assertThat(inputs.get(0).getDistributionPattern()).isEqualTo(POINTWISE);
+ Assertions.assertThat(inputs.get(0).getSource().getResultType())
+ .isEqualTo(ResultPartitionType.BLOCKING_PERSISTENT);
+ Assertions.assertThat(new AbstractID(inputs.get(0).getSourceId()))
+ .isEqualTo(cacheTransformation.getDatasetId());
+ Assertions.assertThat(inputs.get(0).getSource().getProducer().getName())
+ .isEqualTo("map-1 -> map-2 -> Sink: print");
+
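+ // Once the cluster dataset is reported as completed, regenerating the job graph should
+ // produce a CacheRead vertex that consumes the cache instead of recomputing it.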
+ env.addCompletedClusterDataset(cacheTransformation.getDatasetId());
+ cachedStream.print().name("print");
+
+ jobGraph = env.getStreamGraph().getJobGraph();
+ allVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+ Assertions.assertThat(allVertices.size()).isEqualTo(1);
+ Assertions.assertThat(allVertices.get(0).getName()).isEqualTo("CacheRead -> Sink: print");
+ Assertions.assertThat(allVertices.get(0).getIntermediateDataSetIdsToConsume().size())
+ .isEqualTo(1);
+ Assertions.assertThat(
+ new AbstractID(
+ allVertices.get(0).getIntermediateDataSetIdsToConsume().get(0)))
+ .isEqualTo(cacheTransformation.getDatasetId());
}
private JobGraph createJobGraphWithDescription(
@@ -1610,15 +1698,16 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
private void assertSameSlotSharingGroup(JobVertex... vertices) {
for (int i = 0; i < vertices.length - 1; i++) {
- assertEquals(vertices[i].getSlotSharingGroup(), vertices[i + 1].getSlotSharingGroup());
+ Assertions.assertThat(vertices[i + 1].getSlotSharingGroup())
+ .isEqualTo(vertices[i].getSlotSharingGroup());
}
}
private void assertDistinctSharingGroups(JobVertex... vertices) {
for (int i = 0; i < vertices.length - 1; i++) {
for (int j = i + 1; j < vertices.length; j++) {
- assertNotEquals(
- vertices[i].getSlotSharingGroup(), vertices[j].getSlotSharingGroup());
+ Assertions.assertThat(vertices[i].getSlotSharingGroup())
+ .isNotEqualTo(vertices[j].getSlotSharingGroup());
}
}
}
@@ -1683,4 +1772,17 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
super(environment, transformation);
}
}
+
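+ /**
+ * A {@link StreamExecutionEnvironment} that lets a test mark cluster datasets as
+ * completed, so that job graph generation takes the cache-consuming path.
+ */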
+ private static class TestingStreamExecutionEnvironment extends StreamExecutionEnvironment {
+ private final Set<AbstractID> completedClusterDatasetIds = new HashSet<>();
+
+ public void addCompletedClusterDataset(AbstractID id) {
+ completedClusterDatasetIds.add(id);
+ }
+
+ @Override
+ public Set<AbstractID> listCompletedClusterDatasets() {
+ return new HashSet<>(completedClusterDatasetIds);
+ }
+ }
}
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/CachedDataStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/CachedDataStream.scala
new file mode 100644
index 00000000000..fe452f47cb6
--- /dev/null
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/CachedDataStream.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.streaming.api.scala
+
+import org.apache.flink.annotation.PublicEvolving
+import org.apache.flink.streaming.api.datastream.{CachedDataStream => JavaCachedDataStream}
+
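+/**
+ * [[CachedDataStream]] represents a [[DataStream]] whose intermediate result is cached the
+ * first time it is computed, so that later jobs can consume it without recomputing.
+ */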
+@PublicEvolving
+class CachedDataStream[T](javaStream: JavaCachedDataStream[T])
+ extends DataStream[T](javaStream: JavaCachedDataStream[T]) {
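+  /**
+   * Invalidate the cache. The cached intermediate result is released and will be recomputed
+   * the next time this stream is executed.
+   */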
+ def invalidate(): Unit = {
+ javaStream.invalidate()
+ }
+}
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
index e0f05a2cb64..c90decba51a 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
@@ -1282,4 +1282,12 @@ class DataStream[T](stream: JavaStream[T]) {
throw new UnsupportedOperationException("Only supported for operators.")
this
}
+
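+  /**
+   * Cache the intermediate result of this transformation. The cache is produced the first
+   * time the containing job runs and is reused by later jobs until it is invalidated via
+   * [[CachedDataStream.invalidate]] or the environment is closed.
+   */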
+ @PublicEvolving
+ def cache(): CachedDataStream[T] = stream match {
+ case stream: SingleOutputStreamOperator[T] => new CachedDataStream(stream.cache())
+ case stream: SideOutputDataStream[T] => new CachedDataStream(stream.cache())
+ case _ =>
+ throw new UnsupportedOperationException("Operator " + stream + " cannot cache")
+ }
}
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
index 1bbae558cb6..d8d67dba6a9 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
@@ -49,7 +49,7 @@ import java.net.URI
import scala.collection.JavaConverters._
@Public
-class StreamExecutionEnvironment(javaEnv: JavaEnv) {
+class StreamExecutionEnvironment(javaEnv: JavaEnv) extends AutoCloseable {
/** @return the wrapped Java environment */
def getJavaEnv: JavaEnv = javaEnv
@@ -1034,6 +1034,14 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) {
/** Returns whether Unaligned Checkpoints are force-enabled. */
def isForceUnalignedCheckpoints: Boolean = javaEnv.isForceUnalignedCheckpoints
+
+ /**
+ * Close and clean up the execution environment. All cached intermediate results are
+ * physically released.
+ */
+ override def close(): Unit = {
+ javaEnv.close()
+ }
}
object StreamExecutionEnvironment {
diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/StreamingScalaAPICompletenessTest.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/StreamingScalaAPICompletenessTest.scala
index 35b205f2b74..58863749d2b 100644
--- a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/StreamingScalaAPICompletenessTest.scala
+++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/StreamingScalaAPICompletenessTest.scala
@@ -65,6 +65,12 @@ class StreamingScalaAPICompletenessTest extends ScalaAPICompletenessTestBase {
".areExplicitEnvironmentsAllowed",
"org.apache.flink.streaming.api.environment.StreamExecutionEnvironment" +
".registerCollectIterator",
+ "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment" +
+ ".invalidateClusterDataset",
+ "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment" +
+ ".listCompletedClusterDatasets",
+ "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment" +
+ ".registerCacheTransformation",
// TypeHints are only needed for Java API, Scala API doesn't need them
"org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator.returns",
diff --git a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/MiniClusterPipelineExecutorServiceLoader.java b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/MiniClusterPipelineExecutorServiceLoader.java
index b181d51b6cd..c55118202fa 100644
--- a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/MiniClusterPipelineExecutorServiceLoader.java
+++ b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/MiniClusterPipelineExecutorServiceLoader.java
@@ -25,6 +25,7 @@ import org.apache.flink.configuration.ConfigUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.DeploymentOptions;
import org.apache.flink.configuration.PipelineOptions;
+import org.apache.flink.core.execution.CacheSupportedPipelineExecutor;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.core.execution.PipelineExecutor;
import org.apache.flink.core.execution.PipelineExecutorFactory;
@@ -36,6 +37,7 @@ import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.minicluster.MiniClusterJobClient;
import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.util.AbstractID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -44,6 +46,7 @@ import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Collection;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
@@ -141,7 +144,7 @@ public class MiniClusterPipelineExecutorServiceLoader implements PipelineExecuto
}
}
- private static class MiniClusterExecutor implements PipelineExecutor {
+ private static class MiniClusterExecutor implements CacheSupportedPipelineExecutor {
private final MiniCluster miniCluster;
@@ -169,5 +172,20 @@ public class MiniClusterPipelineExecutorServiceLoader implements PipelineExecuto
userCodeClassLoader,
MiniClusterJobClient.JobFinalizationBehavior.NOTHING));
}
+
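+ // The MiniCluster itself tracks completed cluster datasets, so cache support here is a
+ // plain delegation to it.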
+ @Override
+ public CompletableFuture<Set<AbstractID>> listCompletedClusterDatasetIds(
+ Configuration configuration, ClassLoader userCodeClassloader) throws Exception {
+ return miniCluster.listCompletedClusterDatasetIds();
+ }
+
+ @Override
+ public CompletableFuture<Void> invalidateClusterDataset(
+ AbstractID clusterDatasetId,
+ Configuration configuration,
+ ClassLoader userCodeClassloader)
+ throws Exception {
+ return miniCluster.invalidateClusterDataset(clusterDatasetId);
+ }
}
}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/CacheITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/CacheITCase.java
new file mode 100644
index 00000000000..b3cd7298a56
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/CacheITCase.java
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.streaming.runtime;
+
+import org.apache.flink.api.common.RuntimeExecutionMode;
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.serialization.SimpleStringEncoder;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.connector.file.src.reader.TextLineInputFormat;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.minicluster.RpcServiceSharing;
+import org.apache.flink.runtime.scheduler.ClusterDatasetCorruptedException;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.datastream.CachedDataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamSource;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.ProcessFunction;
+import org.apache.flink.streaming.api.transformations.CacheTransformation;
+import org.apache.flink.streaming.util.TestStreamEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.util.AbstractID;
+import org.apache.flink.util.CloseableIterator;
+import org.apache.flink.util.CollectionUtil;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.io.FileUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.UUID;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Integration tests for caching the intermediate result of a DataStream. */
+public class CacheITCase extends AbstractTestBase {
+
+ private StreamExecutionEnvironment env;
+ private MiniClusterWithClientResource miniClusterWithClientResource;
+
+ @BeforeEach
+ void setUp() throws Exception {
+
+ final Configuration configuration = new Configuration();
+ miniClusterWithClientResource =
+ new MiniClusterWithClientResource(
+ new MiniClusterResourceConfiguration.Builder()
+ .setConfiguration(configuration)
+ .setNumberTaskManagers(1)
+ .setNumberSlotsPerTaskManager(8)
+ .setRpcServiceSharing(RpcServiceSharing.DEDICATED)
+ .withHaLeadershipControl()
+ .build());
+ miniClusterWithClientResource.before();
+
+ env = new TestStreamEnvironment(miniClusterWithClientResource.getMiniCluster(), 8);
+ env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+ }
+
+ @AfterEach
+ void tearDown() {
+ miniClusterWithClientResource.after();
+ }
+
+ @Test
+ void testCacheProduceAndConsume(@TempDir java.nio.file.Path tmpDir) throws Exception {
+ File file = prepareTestData(tmpDir);
+
+ FileSource<String> source =
+ FileSource.forRecordStreamFormat(
+ new TextLineInputFormat(),
+ new org.apache.flink.core.fs.Path(file.getPath()))
+ .build();
+ final CachedDataStream<Integer> cachedDataStream =
+ env.fromSource(source, WatermarkStrategy.noWatermarks(), "source")
+ .map(i -> Integer.parseInt(i) + 1)
+ .cache();
+
+ try (CloseableIterator<Integer> resultIterator = cachedDataStream.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(2, 3, 4);
+ }
+
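+ // Delete the source file: the second execution can only reproduce the result if it
+ // reads from the cached intermediate result rather than from the source.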
+ assertThat(file.delete()).isTrue();
+
+ try (CloseableIterator<Integer> resultIterator = cachedDataStream.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(2, 3, 4);
+ }
+ }
+
+ @Test
+ void testInvalidateCache(@TempDir java.nio.file.Path tmpDir) throws Exception {
+ File file = prepareTestData(tmpDir);
+
+ FileSource<String> source =
+ FileSource.forRecordStreamFormat(
+ new TextLineInputFormat(),
+ new org.apache.flink.core.fs.Path(file.getPath()))
+ .build();
+
+ final CachedDataStream<Integer> cachedDataStream =
+ env.fromSource(source, WatermarkStrategy.noWatermarks(), "source")
+ .map(i -> Integer.parseInt(i) + 1)
+ .cache();
+
+ try (CloseableIterator<Integer> resultIterator = cachedDataStream.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(2, 3, 4);
+ }
+
+ assertThat(file.delete()).isTrue();
+
+ try (CloseableIterator<Integer> resultIterator = cachedDataStream.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(2, 3, 4);
+ }
+
+ cachedDataStream.invalidate();
+
+ // overwrite the content of the source file
+ try (FileWriter writer = new FileWriter(file)) {
+ writer.write("4\n5\n6\n");
+ }
+
+ // After the cache is invalidated, the job should re-read the source file and pick up
+ // the updated content.
+ try (CloseableIterator<Integer> resultIterator = cachedDataStream.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(5, 6, 7);
+ }
+ }
+
+ @Test
+ void testBatchProduceCacheStreamConsume(@TempDir java.nio.file.Path tmpDir) throws Exception {
+
+ File file = prepareTestData(tmpDir);
+
+ FileSource<String> source =
+ FileSource.forRecordStreamFormat(
+ new TextLineInputFormat(),
+ new org.apache.flink.core.fs.Path(file.getPath()))
+ .build();
+ final CachedDataStream<Integer> cachedDataStream =
+ env.fromSource(source, WatermarkStrategy.noWatermarks(), "source")
+ .map(Integer::parseInt)
+ .map(i -> i + 1)
+ .cache();
+
+ try (CloseableIterator<Integer> resultIterator = cachedDataStream.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(2, 3, 4);
+ }
+
+ assertThat(file.delete()).isTrue();
+
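+ // Consume the cache from a STREAMING job; the deleted source file guarantees the result
+ // comes from the cache produced by the earlier BATCH run.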
+ env.setRuntimeMode(RuntimeExecutionMode.STREAMING);
+ try (CloseableIterator<Integer> resultIterator =
+ cachedDataStream.map(i -> i + 1).executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(3, 4, 5);
+ }
+ }
+
+ @Test
+ void testCacheProduceAndConsumeWithDifferentPartitioner() throws Exception {
+
+ final DataStreamSource<Tuple2<Integer, Integer>> ds =
+ env.fromElements(new Tuple2<>(1, 1), new Tuple2<>(2, 1), new Tuple2<>(2, 1));
+
+ final CachedDataStream<Tuple2<Integer, Integer>> cacheSource = ds.cache();
+ SingleOutputStreamOperator<Tuple2<Integer, Integer>> result =
+ cacheSource.keyBy(v -> v.f0).reduce((v1, v2) -> new Tuple2<>(v1.f0, v1.f1 + v2.f1));
+
+ try (CloseableIterator<Tuple2<Integer, Integer>> resultIterator =
+ result.executeAndCollect()) {
+ List<Tuple2<Integer, Integer>> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(new Tuple2<>(1, 1), new Tuple2<>(2, 2));
+ }
+
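+ // Consume the same cache with a different key, i.e. a partitioner other than the one
+ // the cache was produced with.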
+ result =
+ cacheSource.keyBy(t -> t.f1).reduce((v1, v2) -> new Tuple2<>(v1.f0 + v2.f0, v1.f1));
+
+ try (CloseableIterator<Tuple2<Integer, Integer>> resultIterator =
+ result.executeAndCollect()) {
+ List<Tuple2<Integer, Integer>> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(new Tuple2<>(5, 1));
+ }
+ }
+
+ @Test
+ void testCacheSideOutput() throws Exception {
+ OutputTag<Integer> tag = new OutputTag<Integer>("2") {};
+ final DataStreamSource<Tuple2<Integer, Integer>> ds =
+ env.fromElements(new Tuple2<>(1, 1), new Tuple2<>(2, 1), new Tuple2<>(2, 2));
+
+ final SingleOutputStreamOperator<Integer> processed =
+ ds.process(
+ new ProcessFunction<Tuple2<Integer, Integer>, Integer>() {
+ @Override
+ public void processElement(
+ Tuple2<Integer, Integer> v,
+ ProcessFunction<Tuple2<Integer, Integer>, Integer>.Context ctx,
+ Collector<Integer> out) {
+ if (v.f0 == 2) {
+ ctx.output(tag, v.f1);
+ return;
+ }
+ out.collect(v.f1);
+ }
+ });
+
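+ // Side outputs are exposed as SideOutputDataStream, which supports cache() as well.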
+ final CachedDataStream<Integer> cachedSideOutput = processed.getSideOutput(tag).cache();
+
+ try (CloseableIterator<Integer> resultIterator = cachedSideOutput.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(1, 2);
+ }
+
+ try (CloseableIterator<Integer> resultIterator = cachedSideOutput.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(1, 2);
+ }
+ }
+
+ @Test
+ void testRetryOnCorruptedClusterDataset(@TempDir java.nio.file.Path tmpDir) throws Exception {
+ File file = prepareTestData(tmpDir);
+
+ FileSource<String> source =
+ FileSource.forRecordStreamFormat(
+ new TextLineInputFormat(),
+ new org.apache.flink.core.fs.Path(file.getPath()))
+ .build();
+ final CachedDataStream<Integer> cachedDataStream =
+ env.fromSource(source, WatermarkStrategy.noWatermarks(), "source")
+ .map(i -> Integer.parseInt(i) + 1)
+ .cache();
+
+ try (CloseableIterator<Integer> resultIterator = cachedDataStream.executeAndCollect()) {
+ List<Integer> results = CollectionUtil.iteratorToList(resultIterator);
+ assertThat(results).containsExactlyInAnyOrder(2, 3, 4);
+ }
+
+ final AbstractID datasetId =
+ ((CacheTransformation<Integer>) cachedDataStream.getTransformation())
+ .getDatasetId();
+
+ assertThat(file.delete()).isTrue();
+ // overwrite the content of the source file
+ try (FileWriter writer = new FileWriter(file)) {
+ writer.write("4\n5\n6\n");
+ }
+
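+ // The flatMap below throws ClusterDatasetCorruptedException for cached values (< 5);
+ // the scheduler is expected to invalidate the cache and re-run the producer against the
+ // rewritten source, so the retried job sees 5, 6 and 7 and succeeds.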
+ final File outputFile = new File(tmpDir.toFile(), UUID.randomUUID().toString());
+ cachedDataStream
+ .flatMap(
+ (FlatMapFunction<Integer, Integer>)
+ (value, out) -> {
+ if (value < 5) {
+ // Simulate ClusterDatasetCorruptedException.
+ throw new ClusterDatasetCorruptedException(
+ null,
+ Collections.singletonList(
+ new IntermediateDataSetID(datasetId)));
+ }
+ out.collect(value);
+ })
+ .returns(Integer.class)
+ .sinkTo(
+ FileSink.forRowFormat(
+ new org.apache.flink.core.fs.Path(outputFile.getPath()),
+ new SimpleStringEncoder<Integer>())
+ .build());
+ env.execute();
+ assertThat(getFileContents(outputFile)).containsExactlyInAnyOrder("5", "6", "7");
+ }
+
+ private static List<String> getFileContents(File directory) throws IOException {
+ List<String> res = new ArrayList<>();
+
+ final Collection<File> filesInBucket = FileUtils.listFiles(directory, null, true);
+ for (File file : filesInBucket) {
+ res.addAll(Arrays.asList(FileUtils.readFileToString(file).split("\n")));
+ }
+
+ return res;
+ }
+
+ private File prepareTestData(Path tmpDir) throws IOException {
+ final File datafile = new File(tmpDir.toFile(), UUID.randomUUID().toString());
+ try (FileWriter writer = new FileWriter(datafile)) {
+ writer.write("1\n2\n3\n");
+ }
+ return datafile;
+ }
+}