You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/15 06:38:15 UTC

[flink-ml] 01/03: [FLINK-24842][iteration] Make outputs depend on tails for the iteration body

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 2ef6b8702295f246739a02d7335c1a7a1a010a94
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Nov 11 16:04:23 2021 +0800

    [FLINK-24842][iteration] Make outputs depend on tails for the iteration body
    
    This closes #31.
---
 .../org/apache/flink/iteration/Iterations.java     | 63 +++++++++++++++++-----
 .../flink/iteration/IterationConstructionTest.java | 37 +++++++------
 2 files changed, 73 insertions(+), 27 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
index 2a3fb39..514f31a 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration;
 
 import org.apache.flink.annotation.Experimental;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
 import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.iteration.operator.HeadOperator;
 import org.apache.flink.iteration.operator.HeadOperatorFactory;
@@ -259,22 +260,29 @@ public class Iterations {
             tails.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
         }
 
+        List<DataStream<?>> tailsAndCriteriaTails = new ArrayList<>(tails.getDataStreams());
         checkState(
                 mayHaveCriteria || iterationBodyResult.getTerminationCriteria() == null,
                 "The current iteration type does not support the termination criteria.");
 
         if (iterationBodyResult.getTerminationCriteria() != null) {
-            addCriteriaStream(
-                    iterationBodyResult.getTerminationCriteria(),
-                    iterationId,
-                    env,
-                    draftEnv,
-                    initVariableStreams,
-                    headStreams,
-                    totalInitVariableParallelism);
+            DataStreamList criteriaTails =
+                    addCriteriaStream(
+                            iterationBodyResult.getTerminationCriteria(),
+                            iterationId,
+                            env,
+                            draftEnv,
+                            initVariableStreams,
+                            headStreams,
+                            totalInitVariableParallelism);
+            tailsAndCriteriaTails.addAll(criteriaTails.getDataStreams());
         }
 
-        return addOutputs(getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv));
+        DataStream<Object> tailsUnion =
+                unionAllTails(env, new DataStreamList(tailsAndCriteriaTails));
+
+        return addOutputs(
+                getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv), tailsUnion);
     }
 
     private static DataStreamList addReplayer(
@@ -315,7 +323,7 @@ public class Iterations {
         return new DataStreamList(result);
     }
 
-    private static void addCriteriaStream(
+    private static DataStreamList addCriteriaStream(
             DataStream<?> draftCriteriaStream,
             IterationID iterationId,
             StreamExecutionEnvironment env,
@@ -364,9 +372,11 @@ public class Iterations {
         criteriaHeaders.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
         criteriaTails.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
 
-        // Now we notify all the head operators to count the criteria stream.
+        // Now we notify all the head operators to count the criteria streams.
         setCriteriaParallelism(headStreams, terminationCriteria.getParallelism());
         setCriteriaParallelism(criteriaHeaders, terminationCriteria.getParallelism());
+
+        return criteriaTails;
     }
 
     @SuppressWarnings({"unchecked", "rawtypes"})
@@ -394,6 +404,24 @@ public class Iterations {
         return criteriaDraftEnv.getActualStream(draftMergedStream.getId());
     }
 
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private static DataStream<Object> unionAllTails(
+            StreamExecutionEnvironment env, DataStreamList tailsAndCriteriaTails) {
+        return Iterations.<DataStream>map(
+                        tailsAndCriteriaTails,
+                        tail ->
+                                tail.filter(r -> false)
+                                        .name("filter-tail")
+                                        .returns(new GenericTypeInfo(Object.class))
+                                        .setParallelism(
+                                                tail.getParallelism() > 0
+                                                        ? tail.getParallelism()
+                                                        : env.getConfig().getParallelism()))
+                .stream()
+                .reduce(DataStream::union)
+                .get();
+    }
+
     private static List<TypeInformation<?>> getTypeInfos(DataStreamList dataStreams) {
         return map(dataStreams, DataStream::getType);
     }
@@ -453,7 +481,8 @@ public class Iterations {
                                         .setParallelism(dataStream.getParallelism())));
     }
 
-    private static DataStreamList addOutputs(DataStreamList dataStreams) {
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private static DataStreamList addOutputs(DataStreamList dataStreams, DataStream tailsUnion) {
         return new DataStreamList(
                 map(
                         dataStreams,
@@ -461,6 +490,16 @@ public class Iterations {
                             IterationRecordTypeInfo<?> inputType =
                                     (IterationRecordTypeInfo<?>) dataStream.getType();
                             return dataStream
+                                    .union(
+                                            tailsUnion
+                                                    .map(x -> x)
+                                                    .name(
+                                                            "tail-map-"
+                                                                    + dataStream
+                                                                            .getTransformation()
+                                                                            .getName())
+                                                    .returns(inputType)
+                                                    .setParallelism(1))
                                     .transform(
                                             "output-" + dataStream.getTransformation().getName(),
                                             inputType.getInnerTypeInfo(),
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
index f2ec465..5844b5a 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
@@ -63,7 +63,7 @@ public class IterationConstructionTest extends TestLogger {
                 Arrays.asList(
                         /* 0 */ "Source: Variable -> input-Variable",
                         /* 1 */ "head-Variable",
-                        /* 2 */ "tail-head-Variable");
+                        /* 2 */ "tail-head-Variable -> filter-tail");
         List<Integer> expectedParallelisms = Arrays.asList(4, 4, 4);
 
         List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
@@ -102,7 +102,7 @@ public class IterationConstructionTest extends TestLogger {
                         /* 0 */ "Source: Variable",
                         /* 1 */ "map -> input-map",
                         /* 2 */ "head-map",
-                        /* 3 */ "tail-head-map");
+                        /* 3 */ "tail-head-map -> filter-tail");
         List<Integer> expectedParallelisms = Arrays.asList(4, 2, 2, 2);
 
         List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
@@ -191,12 +191,14 @@ public class IterationConstructionTest extends TestLogger {
                         /* 2 */ "Source: Constant -> input-Constant",
                         /* 3 */ "head-Variable0",
                         /* 4 */ "head-Variable1",
-                        /* 5 */ "Processor -> output-SideOutput -> Sink: Sink",
+                        /* 5 */ "Processor",
                         /* 6 */ "Feedback0",
-                        /* 7 */ "tail-Feedback0",
+                        /* 7 */ "tail-Feedback0 -> filter-tail",
                         /* 8 */ "Feedback1",
-                        /* 9 */ "tail-Feedback1");
-        List<Integer> expectedParallelisms = Arrays.asList(2, 3, 3, 2, 3, 4, 2, 2, 3, 3);
+                        /* 9 */ "tail-Feedback1 -> filter-tail",
+                        /* 10 */ "tail-map-SideOutput",
+                        /* 11 */ "output-SideOutput -> Sink: Sink");
+        List<Integer> expectedParallelisms = Arrays.asList(2, 3, 3, 2, 3, 4, 2, 2, 3, 3, 1, 4);
 
         JobGraph jobGraph = env.getStreamGraph().getJobGraph();
         List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
@@ -286,17 +288,19 @@ public class IterationConstructionTest extends TestLogger {
                         /* 3 */ "Source: Termination -> input-Termination",
                         /* 4 */ "head-Variable0",
                         /* 5 */ "head-Variable1",
-                        /* 6 */ "Processor -> output-SideOutput -> Sink: Sink",
+                        /* 6 */ "Processor",
                         /* 7 */ "Feedback0",
-                        /* 8 */ "tail-Feedback0",
+                        /* 8 */ "tail-Feedback0 -> filter-tail",
                         /* 9 */ "Feedback1",
-                        /* 10 */ "tail-Feedback1",
+                        /* 10 */ "tail-Feedback1 -> filter-tail",
                         /* 11 */ "Termination",
                         /* 12 */ "head-Termination",
                         /* 13 */ "criteria-merge",
-                        /* 14 */ "tail-criteria-merge");
+                        /* 14 */ "tail-criteria-merge -> filter-tail",
+                        /* 15 */ "tail-map-SideOutput",
+                        /* 16 */ "output-SideOutput -> Sink: Sink");
         List<Integer> expectedParallelisms =
-                Arrays.asList(2, 3, 3, 5, 2, 3, 4, 2, 2, 3, 3, 5, 5, 5, 5);
+                Arrays.asList(2, 3, 3, 5, 2, 3, 4, 2, 2, 3, 3, 5, 5, 5, 5, 1, 4);
 
         JobGraph jobGraph = env.getStreamGraph().getJobGraph();
         List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
@@ -380,14 +384,17 @@ public class IterationConstructionTest extends TestLogger {
                         /* 2 */ "Source: Termination -> input-Termination",
                         /* 3 */ "head-Variable",
                         /* 4 */ "Replayer-Constant",
-                        /* 5 */ "Processor -> output-SideOutput -> Sink: Sink",
+                        /* 5 */ "Processor",
                         /* 6 */ "Feedback",
-                        /* 7 */ "tail-Feedback",
+                        /* 7 */ "tail-Feedback -> filter-tail",
                         /* 8 */ "Termination",
                         /* 9 */ "head-Termination",
                         /* 10 */ "criteria-merge",
-                        /* 11 */ "tail-criteria-merge");
-        List<Integer> expectedParallelisms = Arrays.asList(2, 3, 5, 2, 3, 4, 2, 2, 5, 5, 5, 5);
+                        /* 11 */ "tail-criteria-merge -> filter-tail",
+                        /* 12 */ "tail-map-SideOutput",
+                        /* 13 */ "output-SideOutput -> Sink: Sink");
+        List<Integer> expectedParallelisms =
+                Arrays.asList(2, 3, 5, 2, 3, 4, 2, 2, 5, 5, 5, 5, 1, 4);
 
         JobGraph jobGraph = env.getStreamGraph().getJobGraph();
         List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();