You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/15 06:38:15 UTC
[flink-ml] 01/03: [FLINK-24842][iteration] Make outputs depend on tails for the iteration body
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 2ef6b8702295f246739a02d7335c1a7a1a010a94
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Nov 11 16:04:23 2021 +0800
[FLINK-24842][iteration] Make outputs depend on tails for the iteration body
This closes #31.
---
.../org/apache/flink/iteration/Iterations.java | 63 +++++++++++++++++-----
.../flink/iteration/IterationConstructionTest.java | 37 +++++++------
2 files changed, 73 insertions(+), 27 deletions(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
index 2a3fb39..514f31a 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration;
import org.apache.flink.annotation.Experimental;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
import org.apache.flink.iteration.operator.HeadOperator;
import org.apache.flink.iteration.operator.HeadOperatorFactory;
@@ -259,22 +260,29 @@ public class Iterations {
tails.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
}
+ List<DataStream<?>> tailsAndCriteriaTails = new ArrayList<>(tails.getDataStreams());
checkState(
mayHaveCriteria || iterationBodyResult.getTerminationCriteria() == null,
"The current iteration type does not support the termination criteria.");
if (iterationBodyResult.getTerminationCriteria() != null) {
- addCriteriaStream(
- iterationBodyResult.getTerminationCriteria(),
- iterationId,
- env,
- draftEnv,
- initVariableStreams,
- headStreams,
- totalInitVariableParallelism);
+ DataStreamList criteriaTails =
+ addCriteriaStream(
+ iterationBodyResult.getTerminationCriteria(),
+ iterationId,
+ env,
+ draftEnv,
+ initVariableStreams,
+ headStreams,
+ totalInitVariableParallelism);
+ tailsAndCriteriaTails.addAll(criteriaTails.getDataStreams());
}
- return addOutputs(getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv));
+ DataStream<Object> tailsUnion =
+ unionAllTails(env, new DataStreamList(tailsAndCriteriaTails));
+
+ return addOutputs(
+ getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv), tailsUnion);
}
private static DataStreamList addReplayer(
@@ -315,7 +323,7 @@ public class Iterations {
return new DataStreamList(result);
}
- private static void addCriteriaStream(
+ private static DataStreamList addCriteriaStream(
DataStream<?> draftCriteriaStream,
IterationID iterationId,
StreamExecutionEnvironment env,
@@ -364,9 +372,11 @@ public class Iterations {
criteriaHeaders.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
criteriaTails.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
- // Now we notify all the head operators to count the criteria stream.
+ // Now we notify all the head operators to count the criteria streams.
setCriteriaParallelism(headStreams, terminationCriteria.getParallelism());
setCriteriaParallelism(criteriaHeaders, terminationCriteria.getParallelism());
+
+ return criteriaTails;
}
@SuppressWarnings({"unchecked", "rawtypes"})
@@ -394,6 +404,24 @@ public class Iterations {
return criteriaDraftEnv.getActualStream(draftMergedStream.getId());
}
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ private static DataStream<Object> unionAllTails(
+ StreamExecutionEnvironment env, DataStreamList tailsAndCriteriaTails) {
+ return Iterations.<DataStream>map(
+ tailsAndCriteriaTails,
+ tail ->
+ tail.filter(r -> false)
+ .name("filter-tail")
+ .returns(new GenericTypeInfo(Object.class))
+ .setParallelism(
+ tail.getParallelism() > 0
+ ? tail.getParallelism()
+ : env.getConfig().getParallelism()))
+ .stream()
+ .reduce(DataStream::union)
+ .get();
+ }
+
private static List<TypeInformation<?>> getTypeInfos(DataStreamList dataStreams) {
return map(dataStreams, DataStream::getType);
}
@@ -453,7 +481,8 @@ public class Iterations {
.setParallelism(dataStream.getParallelism())));
}
- private static DataStreamList addOutputs(DataStreamList dataStreams) {
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ private static DataStreamList addOutputs(DataStreamList dataStreams, DataStream tailsUnion) {
return new DataStreamList(
map(
dataStreams,
@@ -461,6 +490,16 @@ public class Iterations {
IterationRecordTypeInfo<?> inputType =
(IterationRecordTypeInfo<?>) dataStream.getType();
return dataStream
+ .union(
+ tailsUnion
+ .map(x -> x)
+ .name(
+ "tail-map-"
+ + dataStream
+ .getTransformation()
+ .getName())
+ .returns(inputType)
+ .setParallelism(1))
.transform(
"output-" + dataStream.getTransformation().getName(),
inputType.getInnerTypeInfo(),
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
index f2ec465..5844b5a 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
@@ -63,7 +63,7 @@ public class IterationConstructionTest extends TestLogger {
Arrays.asList(
/* 0 */ "Source: Variable -> input-Variable",
/* 1 */ "head-Variable",
- /* 2 */ "tail-head-Variable");
+ /* 2 */ "tail-head-Variable -> filter-tail");
List<Integer> expectedParallelisms = Arrays.asList(4, 4, 4);
List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
@@ -102,7 +102,7 @@ public class IterationConstructionTest extends TestLogger {
/* 0 */ "Source: Variable",
/* 1 */ "map -> input-map",
/* 2 */ "head-map",
- /* 3 */ "tail-head-map");
+ /* 3 */ "tail-head-map -> filter-tail");
List<Integer> expectedParallelisms = Arrays.asList(4, 2, 2, 2);
List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
@@ -191,12 +191,14 @@ public class IterationConstructionTest extends TestLogger {
/* 2 */ "Source: Constant -> input-Constant",
/* 3 */ "head-Variable0",
/* 4 */ "head-Variable1",
- /* 5 */ "Processor -> output-SideOutput -> Sink: Sink",
+ /* 5 */ "Processor",
/* 6 */ "Feedback0",
- /* 7 */ "tail-Feedback0",
+ /* 7 */ "tail-Feedback0 -> filter-tail",
/* 8 */ "Feedback1",
- /* 9 */ "tail-Feedback1");
- List<Integer> expectedParallelisms = Arrays.asList(2, 3, 3, 2, 3, 4, 2, 2, 3, 3);
+ /* 9 */ "tail-Feedback1 -> filter-tail",
+ /* 10 */ "tail-map-SideOutput",
+ /* 11 */ "output-SideOutput -> Sink: Sink");
+ List<Integer> expectedParallelisms = Arrays.asList(2, 3, 3, 2, 3, 4, 2, 2, 3, 3, 1, 4);
JobGraph jobGraph = env.getStreamGraph().getJobGraph();
List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
@@ -286,17 +288,19 @@ public class IterationConstructionTest extends TestLogger {
/* 3 */ "Source: Termination -> input-Termination",
/* 4 */ "head-Variable0",
/* 5 */ "head-Variable1",
- /* 6 */ "Processor -> output-SideOutput -> Sink: Sink",
+ /* 6 */ "Processor",
/* 7 */ "Feedback0",
- /* 8 */ "tail-Feedback0",
+ /* 8 */ "tail-Feedback0 -> filter-tail",
/* 9 */ "Feedback1",
- /* 10 */ "tail-Feedback1",
+ /* 10 */ "tail-Feedback1 -> filter-tail",
/* 11 */ "Termination",
/* 12 */ "head-Termination",
/* 13 */ "criteria-merge",
- /* 14 */ "tail-criteria-merge");
+ /* 14 */ "tail-criteria-merge -> filter-tail",
+ /* 15 */ "tail-map-SideOutput",
+ /* 16 */ "output-SideOutput -> Sink: Sink");
List<Integer> expectedParallelisms =
- Arrays.asList(2, 3, 3, 5, 2, 3, 4, 2, 2, 3, 3, 5, 5, 5, 5);
+ Arrays.asList(2, 3, 3, 5, 2, 3, 4, 2, 2, 3, 3, 5, 5, 5, 5, 1, 4);
JobGraph jobGraph = env.getStreamGraph().getJobGraph();
List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
@@ -380,14 +384,17 @@ public class IterationConstructionTest extends TestLogger {
/* 2 */ "Source: Termination -> input-Termination",
/* 3 */ "head-Variable",
/* 4 */ "Replayer-Constant",
- /* 5 */ "Processor -> output-SideOutput -> Sink: Sink",
+ /* 5 */ "Processor",
/* 6 */ "Feedback",
- /* 7 */ "tail-Feedback",
+ /* 7 */ "tail-Feedback -> filter-tail",
/* 8 */ "Termination",
/* 9 */ "head-Termination",
/* 10 */ "criteria-merge",
- /* 11 */ "tail-criteria-merge");
- List<Integer> expectedParallelisms = Arrays.asList(2, 3, 5, 2, 3, 4, 2, 2, 5, 5, 5, 5);
+ /* 11 */ "tail-criteria-merge -> filter-tail",
+ /* 12 */ "tail-map-SideOutput",
+ /* 13 */ "output-SideOutput -> Sink: Sink");
+ List<Integer> expectedParallelisms =
+ Arrays.asList(2, 3, 5, 2, 3, 4, 2, 2, 5, 5, 5, 5, 1, 4);
JobGraph jobGraph = env.getStreamGraph().getJobGraph();
List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();