You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/10 03:43:12 UTC

[flink-ml] branch master updated: [FLINK-24722][iteration] Fix the issues in supporting keyed stream inside the iteration body

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new a368ebb  [FLINK-24722][iteration] Fix the issues in supporting keyed stream inside the iteration body
a368ebb is described below

commit a368ebb17affae872f0ea9eb7bb9576fb56612ee
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Tue Nov 2 17:06:12 2021 +0800

    [FLINK-24722][iteration] Fix the issues in supporting keyed stream inside the iteration body
    
    This closes #22.
---
 .../flink/iteration/operator/OperatorUtils.java    | 23 +++++++++
 .../allround/AbstractAllRoundWrapperOperator.java  |  4 +-
 .../allround/OneInputAllRoundWrapperOperator.java  |  6 ++-
 .../perround/AbstractPerRoundWrapperOperator.java  | 10 ++--
 .../flink/iteration/proxy/ProxyKeySelector.java    |  4 ++
 .../iteration/proxy/ProxyStreamPartitioner.java    | 11 +++++
 .../BoundedAllRoundStreamIterationITCase.java      | 31 ++++++++++--
 .../BoundedPerRoundStreamIterationITCase.java      | 14 +++++-
 .../iteration/UnboundedStreamIterationITCase.java  | 28 +++++++++--
 .../operators/StatefulProcessFunction.java         | 55 ++++++++++++++++++++++
 10 files changed, 171 insertions(+), 15 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
index 292a90e..25d200b 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
@@ -18,15 +18,18 @@
 
 package org.apache.flink.iteration.operator;
 
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
 import org.apache.flink.iteration.utils.ReflectionUtils;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
+import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.util.ExceptionUtils;
@@ -39,6 +42,8 @@ import java.util.Random;
 import java.util.UUID;
 import java.util.concurrent.Executor;
 
+import static org.apache.flink.util.Preconditions.checkState;
+
 /** Utility class for operators. */
 public class OperatorUtils {
 
@@ -83,6 +88,24 @@ public class OperatorUtils {
         }
     }
 
+    public static StreamConfig createWrappedOperatorConfig(StreamConfig wrapperConfig) {
+        StreamConfig wrappedConfig = new StreamConfig(wrapperConfig.getConfiguration().clone());
+        for (int i = 0; i < wrappedConfig.getNumberOfNetworkInputs(); ++i) {
+            KeySelector keySelector =
+                    wrapperConfig.getStatePartitioner(i, OperatorUtils.class.getClassLoader());
+            if (keySelector != null) {
+                checkState(
+                        keySelector instanceof ProxyKeySelector,
+                        "The state partitioner for the wrapper operator should always be ProxyKeySelector, but it is "
+                                + keySelector);
+                wrappedConfig.setStatePartitioner(
+                        i, ((ProxyKeySelector) keySelector).getWrappedKeySelector());
+            }
+        }
+
+        return wrappedConfig;
+    }
+
     public static Path getDataCachePath(Configuration configuration, String[] localSpillPaths) {
         String pathStr = configuration.get(IterationOptions.DATA_CACHE_PATH);
         if (pathStr == null) {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index 0ea742c..180477c 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -29,6 +29,7 @@ import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.AbstractWrapperOperator;
 import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.metrics.groups.OperatorMetricGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
@@ -84,7 +85,8 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
                         StreamOperatorFactoryUtil.<T, S>createOperator(
                                         operatorFactory,
                                         (StreamTask) parameters.getContainingTask(),
-                                        parameters.getStreamConfig(),
+                                        OperatorUtils.createWrappedOperatorConfig(
+                                                parameters.getStreamConfig()),
                                         proxyOutput,
                                         parameters.getOperatorEventDispatcher())
                                 .f0;
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
index 6a2b9a0..7c725d9 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
@@ -78,8 +78,10 @@ public class OneInputAllRoundWrapperOperator<IN, OUT>
 
     @Override
     public void setKeyContextElement(StreamRecord<IterationRecord<IN>> record) throws Exception {
-        reusedInput.replace(record.getValue().getValue(), record.getTimestamp());
-        wrappedOperator.setKeyContextElement(reusedInput);
+        if (record.getValue().getType() == IterationRecord.Type.RECORD) {
+            reusedInput.replace(record.getValue().getValue(), record.getTimestamp());
+            wrappedOperator.setKeyContextElement(reusedInput);
+        }
     }
 
     @Override
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
index 3903340..cc4ac36 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
@@ -27,6 +27,7 @@ import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
 import org.apache.flink.core.memory.ManagedMemoryUseCase;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.AbstractWrapperOperator;
+import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.proxy.state.ProxyStateSnapshotContext;
 import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
 import org.apache.flink.iteration.utils.ReflectionUtils;
@@ -124,7 +125,8 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato
                             StreamOperatorFactoryUtil.<T, S>createOperator(
                                             clonedOperatorFactory,
                                             (StreamTask) parameters.getContainingTask(),
-                                            parameters.getStreamConfig(),
+                                            OperatorUtils.createWrappedOperatorConfig(
+                                                    parameters.getStreamConfig()),
                                             proxyOutput,
                                             parameters.getOperatorEventDispatcher())
                                     .f0;
@@ -294,7 +296,9 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato
 
     private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
             throws Exception {
-        if (selector != null) {
+        if (selector != null
+                && ((IterationRecord<?>) record.getValue()).getType()
+                        == IterationRecord.Type.RECORD) {
             Object key = selector.getKey(record.getValue());
             setCurrentKey(key);
         }
@@ -335,7 +339,7 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato
             return null;
         }
 
-        return stateHandler.getKeyedStateStore().orElse(null);
+        return stateHandler.getCurrentKey();
     }
 
     protected void reportOrForwardLatencyMarker(LatencyMarker marker) {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java
index 1ac64ec..f3615f7 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java
@@ -30,6 +30,10 @@ public class ProxyKeySelector<T, KEY> implements KeySelector<IterationRecord<T>,
         this.wrappedKeySelector = wrappedKeySelector;
     }
 
+    public KeySelector<T, KEY> getWrappedKeySelector() {
+        return wrappedKeySelector;
+    }
+
     @Override
     public KEY getKey(IterationRecord<T> record) throws Exception {
         return wrappedKeySelector.getKey(record.getValue());
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java
index 4accb32..525f12a 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java
@@ -44,6 +44,12 @@ public class ProxyStreamPartitioner<T> extends StreamPartitioner<IterationRecord
     }
 
     @Override
+    public void setup(int numberOfChannels) {
+        super.setup(numberOfChannels);
+        wrappedStreamPartitioner.setup(numberOfChannels);
+    }
+
+    @Override
     public StreamPartitioner<IterationRecord<T>> copy() {
         return new ProxyStreamPartitioner<>(wrappedStreamPartitioner.copy());
     }
@@ -87,4 +93,9 @@ public class ProxyStreamPartitioner<T> extends StreamPartitioner<IterationRecord
             return selectChannel(record);
         }
     }
+
+    @Override
+    public String toString() {
+        return wrappedStreamPartitioner.toString();
+    }
 }
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
index 5084c78..1b28374 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
@@ -37,6 +37,7 @@ import org.apache.flink.test.iteration.operators.IncrementEpochMap;
 import org.apache.flink.test.iteration.operators.OutputRecord;
 import org.apache.flink.test.iteration.operators.RoundBasedTerminationCriteria;
 import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.StatefulProcessFunction;
 import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction;
 import org.apache.flink.testutils.junit.SharedObjects;
 import org.apache.flink.testutils.junit.SharedReference;
@@ -130,6 +131,8 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger {
         // If termination criteria is created only with the constants streams, it would not have
         // records after the round 1 if the input is not replayed.
         int numOfRound = terminationCriteriaFollowsConstantsStreams ? 1 : 5;
+        assertEquals(numOfRound + 1, result.get().size());
+
         Map<Integer, Tuple2<Integer, Integer>> roundsStat =
                 computeRoundStat(
                         result.get(), OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, numOfRound);
@@ -184,9 +187,19 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger {
                                             .process(
                                                     new TwoInputReduceAllRoundProcessFunction(
                                                             sync, maxRound));
+
                             return new IterationBodyResult(
                                     DataStreamList.of(
-                                            reducer.map(new IncrementEpochMap())
+                                            reducer.partitionCustom(
+                                                            (k, numPartitions) -> k % numPartitions,
+                                                            EpochRecord::getValue)
+                                                    .map(x -> x)
+                                                    .keyBy(EpochRecord::getValue)
+                                                    .process(
+                                                            new StatefulProcessFunction<
+                                                                    EpochRecord>() {})
+                                                    .setParallelism(4)
+                                                    .map(new IncrementEpochMap())
                                                     .setParallelism(numSources)),
                                     DataStreamList.of(
                                             reducer.getSideOutput(
@@ -237,10 +250,20 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger {
                                             .process(
                                                     new TwoInputReduceAllRoundProcessFunction(
                                                             sync, maxRound));
+
+                            SingleOutputStreamOperator<EpochRecord> feedbackStream =
+                                    reducer.partitionCustom(
+                                                    (k, numPartitions) -> k % numPartitions,
+                                                    EpochRecord::getValue)
+                                            .map(x -> x)
+                                            .keyBy(EpochRecord::getValue)
+                                            .process(new StatefulProcessFunction<EpochRecord>() {})
+                                            .setParallelism(4)
+                                            .map(new IncrementEpochMap())
+                                            .setParallelism(numSources);
+
                             return new IterationBodyResult(
-                                    DataStreamList.of(
-                                            reducer.map(new IncrementEpochMap())
-                                                    .setParallelism(numSources)),
+                                    DataStreamList.of(feedbackStream),
                                     DataStreamList.of(
                                             reducer.getSideOutput(
                                                     new OutputTag<OutputRecord<Integer>>(
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
index cb36d0c..8bc6f1f 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
@@ -34,6 +34,7 @@ import org.apache.flink.test.iteration.operators.CollectSink;
 import org.apache.flink.test.iteration.operators.EpochRecord;
 import org.apache.flink.test.iteration.operators.OutputRecord;
 import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.StatefulProcessFunction;
 import org.apache.flink.test.iteration.operators.TwoInputReducePerRoundOperator;
 import org.apache.flink.testutils.junit.SharedObjects;
 import org.apache.flink.testutils.junit.SharedReference;
@@ -123,7 +124,18 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger {
                                             .setParallelism(1);
 
                             return new IterationBodyResult(
-                                    DataStreamList.of(reducer.filter(x -> x < maxRound)),
+                                    DataStreamList.of(
+                                            reducer.partitionCustom(
+                                                            (k, numPartitions) -> k % numPartitions,
+                                                            x -> x)
+                                                    .map(x -> x)
+                                                    .keyBy(x -> x)
+                                                    .process(
+                                                            new StatefulProcessFunction<
+                                                                    Integer>() {})
+                                                    .setParallelism(4)
+                                                    .filter(x -> x < maxRound)
+                                                    .setParallelism(1)),
                                     DataStreamList.of(
                                             reducer.getSideOutput(
                                                     TwoInputReducePerRoundOperator.OUTPUT_TAG)),
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
index 6d80f23..f3f2272 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
@@ -38,6 +38,7 @@ import org.apache.flink.test.iteration.operators.IncrementEpochMap;
 import org.apache.flink.test.iteration.operators.OutputRecord;
 import org.apache.flink.test.iteration.operators.ReduceAllRoundProcessFunction;
 import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.StatefulProcessFunction;
 import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction;
 import org.apache.flink.testutils.junit.SharedObjects;
 import org.apache.flink.testutils.junit.SharedReference;
@@ -192,7 +193,16 @@ public class UnboundedStreamIterationITCase extends TestLogger {
                                             new ReduceAllRoundProcessFunction(sync, maxRound));
                             return new IterationBodyResult(
                                     DataStreamList.of(
-                                            reducer.map(new IncrementEpochMap())
+                                            reducer.partitionCustom(
+                                                            (k, numPartitions) -> k % numPartitions,
+                                                            EpochRecord::getValue)
+                                                    .map(x -> x)
+                                                    .keyBy(EpochRecord::getValue)
+                                                    .process(
+                                                            new StatefulProcessFunction<
+                                                                    EpochRecord>() {})
+                                                    .setParallelism(4)
+                                                    .map(new IncrementEpochMap())
                                                     .setParallelism(numSources)),
                                     DataStreamList.of(
                                             reducer.getSideOutput(
@@ -234,10 +244,20 @@ public class UnboundedStreamIterationITCase extends TestLogger {
                                             .process(
                                                     new TwoInputReduceAllRoundProcessFunction(
                                                             sync, maxRound));
+
+                            SingleOutputStreamOperator<EpochRecord> feedbackStream =
+                                    reducer.partitionCustom(
+                                                    (k, numPartitions) -> k % numPartitions,
+                                                    EpochRecord::getValue)
+                                            .map(x -> x)
+                                            .keyBy(EpochRecord::getValue)
+                                            .process(new StatefulProcessFunction<EpochRecord>() {})
+                                            .setParallelism(4)
+                                            .map(new IncrementEpochMap())
+                                            .setParallelism(numSources);
+
                             return new IterationBodyResult(
-                                    DataStreamList.of(
-                                            reducer.map(new IncrementEpochMap())
-                                                    .setParallelism(numSources)),
+                                    DataStreamList.of(feedbackStream),
                                     DataStreamList.of(
                                             reducer.getSideOutput(
                                                     new OutputTag<OutputRecord<Integer>>(
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java
new file mode 100644
index 0000000..47415f5
--- /dev/null
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration.operators;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.util.Collector;
+
+/**
+ * A keyed function that updates a per-key {@link ValueState} on every element and forwards the
+ * element unchanged, so tests can verify that keyed streams work inside an iteration body.
+ */
+public class StatefulProcessFunction<T> extends KeyedProcessFunction<Integer, T, T> {
+
+    private ValueState<Integer> state;
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        this.state =
+                getRuntimeContext().getState(new ValueStateDescriptor<>("state", Integer.class));
+    }
+
+    @Override
+    public void processElement(T value, Context ctx, Collector<T> out) throws Exception {
+        if (state.value() == null) {
+            state.update(0);
+
+            // Register an event-time timer when the current key has no state yet.
+            ctx.timerService().registerEventTimeTimer(1000L);
+        } else {
+            state.update(state.value() + 1);
+        }
+
+        out.collect(value);
+    }
+}