You are viewing a plain-text version of this content. The canonical version is available at the original link, which was not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/10 03:43:12 UTC
[flink-ml] branch master updated: [FLINK-24722][iteration] Fix the
issues in supporting keyed stream inside the iteration body
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new a368ebb [FLINK-24722][iteration] Fix the issues in supporting keyed stream inside the iteration body
a368ebb is described below
commit a368ebb17affae872f0ea9eb7bb9576fb56612ee
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Tue Nov 2 17:06:12 2021 +0800
[FLINK-24722][iteration] Fix the issues in supporting keyed stream inside the iteration body
This closes #22.
---
.../flink/iteration/operator/OperatorUtils.java | 23 +++++++++
.../allround/AbstractAllRoundWrapperOperator.java | 4 +-
.../allround/OneInputAllRoundWrapperOperator.java | 6 ++-
.../perround/AbstractPerRoundWrapperOperator.java | 10 ++--
.../flink/iteration/proxy/ProxyKeySelector.java | 4 ++
.../iteration/proxy/ProxyStreamPartitioner.java | 11 +++++
.../BoundedAllRoundStreamIterationITCase.java | 31 ++++++++++--
.../BoundedPerRoundStreamIterationITCase.java | 14 +++++-
.../iteration/UnboundedStreamIterationITCase.java | 28 +++++++++--
.../operators/StatefulProcessFunction.java | 55 ++++++++++++++++++++++
10 files changed, 171 insertions(+), 15 deletions(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
index 292a90e..25d200b 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
@@ -18,15 +18,18 @@
package org.apache.flink.iteration.operator;
+import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.utils.ReflectionUtils;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
+import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.util.ExceptionUtils;
@@ -39,6 +42,8 @@ import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Executor;
+import static org.apache.flink.util.Preconditions.checkState;
+
/** Utility class for operators. */
public class OperatorUtils {
@@ -83,6 +88,24 @@ public class OperatorUtils {
}
}
+ public static StreamConfig createWrappedOperatorConfig(StreamConfig wrapperConfig) {
+ StreamConfig wrappedConfig = new StreamConfig(wrapperConfig.getConfiguration().clone());
+ for (int i = 0; i < wrappedConfig.getNumberOfNetworkInputs(); ++i) {
+ KeySelector keySelector =
+ wrapperConfig.getStatePartitioner(i, OperatorUtils.class.getClassLoader());
+ if (keySelector != null) {
+ checkState(
+ keySelector instanceof ProxyKeySelector,
+ "The state partitioner for the wrapper operator should always be ProxyKeySelector, but it is "
+ + keySelector);
+ wrappedConfig.setStatePartitioner(
+ i, ((ProxyKeySelector) keySelector).getWrappedKeySelector());
+ }
+ }
+
+ return wrappedConfig;
+ }
+
public static Path getDataCachePath(Configuration configuration, String[] localSpillPaths) {
String pathStr = configuration.get(IterationOptions.DATA_CACHE_PATH);
if (pathStr == null) {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index 0ea742c..180477c 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -29,6 +29,7 @@ import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.AbstractWrapperOperator;
import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
@@ -84,7 +85,8 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
StreamOperatorFactoryUtil.<T, S>createOperator(
operatorFactory,
(StreamTask) parameters.getContainingTask(),
- parameters.getStreamConfig(),
+ OperatorUtils.createWrappedOperatorConfig(
+ parameters.getStreamConfig()),
proxyOutput,
parameters.getOperatorEventDispatcher())
.f0;
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
index 6a2b9a0..7c725d9 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
@@ -78,8 +78,10 @@ public class OneInputAllRoundWrapperOperator<IN, OUT>
@Override
public void setKeyContextElement(StreamRecord<IterationRecord<IN>> record) throws Exception {
- reusedInput.replace(record.getValue().getValue(), record.getTimestamp());
- wrappedOperator.setKeyContextElement(reusedInput);
+ if (record.getValue().getType() == IterationRecord.Type.RECORD) {
+ reusedInput.replace(record.getValue().getValue(), record.getTimestamp());
+ wrappedOperator.setKeyContextElement(reusedInput);
+ }
}
@Override
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
index 3903340..cc4ac36 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
@@ -27,6 +27,7 @@ import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.AbstractWrapperOperator;
+import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.proxy.state.ProxyStateSnapshotContext;
import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
import org.apache.flink.iteration.utils.ReflectionUtils;
@@ -124,7 +125,8 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato
StreamOperatorFactoryUtil.<T, S>createOperator(
clonedOperatorFactory,
(StreamTask) parameters.getContainingTask(),
- parameters.getStreamConfig(),
+ OperatorUtils.createWrappedOperatorConfig(
+ parameters.getStreamConfig()),
proxyOutput,
parameters.getOperatorEventDispatcher())
.f0;
@@ -294,7 +296,9 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato
private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
throws Exception {
- if (selector != null) {
+ if (selector != null
+ && ((IterationRecord<?>) record.getValue()).getType()
+ == IterationRecord.Type.RECORD) {
Object key = selector.getKey(record.getValue());
setCurrentKey(key);
}
@@ -335,7 +339,7 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato
return null;
}
- return stateHandler.getKeyedStateStore().orElse(null);
+ return stateHandler.getCurrentKey();
}
protected void reportOrForwardLatencyMarker(LatencyMarker marker) {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java
index 1ac64ec..f3615f7 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java
@@ -30,6 +30,10 @@ public class ProxyKeySelector<T, KEY> implements KeySelector<IterationRecord<T>,
this.wrappedKeySelector = wrappedKeySelector;
}
+    /**
+     * Returns the key selector this proxy applies to the value carried by each {@code
+     * IterationRecord}.
+     *
+     * @return the wrapped key selector.
+     */
+    public KeySelector<T, KEY> getWrappedKeySelector() {
+        return this.wrappedKeySelector;
+    }
+
@Override
public KEY getKey(IterationRecord<T> record) throws Exception {
return wrappedKeySelector.getKey(record.getValue());
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java
index 4accb32..525f12a 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java
@@ -44,6 +44,12 @@ public class ProxyStreamPartitioner<T> extends StreamPartitioner<IterationRecord
}
@Override
+    public void setup(int numberOfChannels) {
+        // Record the channel count on this proxy first, then hand the same count to the wrapped
+        // partitioner, which presumably relies on it when selecting channels.
+        super.setup(numberOfChannels);
+        this.wrappedStreamPartitioner.setup(numberOfChannels);
+    }
+
+ @Override
public StreamPartitioner<IterationRecord<T>> copy() {
return new ProxyStreamPartitioner<>(wrappedStreamPartitioner.copy());
}
@@ -87,4 +93,9 @@ public class ProxyStreamPartitioner<T> extends StreamPartitioner<IterationRecord
return selectChannel(record);
}
}
+
+    /**
+     * Returns the string form of the wrapped partitioner, so that the proxy is described exactly
+     * like the partitioner it wraps.
+     */
+    @Override
+    public String toString() {
+        return this.wrappedStreamPartitioner.toString();
+    }
}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
index 5084c78..1b28374 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
@@ -37,6 +37,7 @@ import org.apache.flink.test.iteration.operators.IncrementEpochMap;
import org.apache.flink.test.iteration.operators.OutputRecord;
import org.apache.flink.test.iteration.operators.RoundBasedTerminationCriteria;
import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.StatefulProcessFunction;
import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction;
import org.apache.flink.testutils.junit.SharedObjects;
import org.apache.flink.testutils.junit.SharedReference;
@@ -130,6 +131,8 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger {
// If termination criteria is created only with the constants streams, it would not have
// records after the round 1 if the input is not replayed.
int numOfRound = terminationCriteriaFollowsConstantsStreams ? 1 : 5;
+ assertEquals(numOfRound + 1, result.get().size());
+
Map<Integer, Tuple2<Integer, Integer>> roundsStat =
computeRoundStat(
result.get(), OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, numOfRound);
@@ -184,9 +187,19 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger {
.process(
new TwoInputReduceAllRoundProcessFunction(
sync, maxRound));
+
return new IterationBodyResult(
DataStreamList.of(
- reducer.map(new IncrementEpochMap())
+ reducer.partitionCustom(
+ (k, numPartitions) -> k % numPartitions,
+ EpochRecord::getValue)
+ .map(x -> x)
+ .keyBy(EpochRecord::getValue)
+ .process(
+ new StatefulProcessFunction<
+ EpochRecord>() {})
+ .setParallelism(4)
+ .map(new IncrementEpochMap())
.setParallelism(numSources)),
DataStreamList.of(
reducer.getSideOutput(
@@ -237,10 +250,20 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger {
.process(
new TwoInputReduceAllRoundProcessFunction(
sync, maxRound));
+
+ SingleOutputStreamOperator<EpochRecord> feedbackStream =
+ reducer.partitionCustom(
+ (k, numPartitions) -> k % numPartitions,
+ EpochRecord::getValue)
+ .map(x -> x)
+ .keyBy(EpochRecord::getValue)
+ .process(new StatefulProcessFunction<EpochRecord>() {})
+ .setParallelism(4)
+ .map(new IncrementEpochMap())
+ .setParallelism(numSources);
+
return new IterationBodyResult(
- DataStreamList.of(
- reducer.map(new IncrementEpochMap())
- .setParallelism(numSources)),
+ DataStreamList.of(feedbackStream),
DataStreamList.of(
reducer.getSideOutput(
new OutputTag<OutputRecord<Integer>>(
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
index cb36d0c..8bc6f1f 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
@@ -34,6 +34,7 @@ import org.apache.flink.test.iteration.operators.CollectSink;
import org.apache.flink.test.iteration.operators.EpochRecord;
import org.apache.flink.test.iteration.operators.OutputRecord;
import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.StatefulProcessFunction;
import org.apache.flink.test.iteration.operators.TwoInputReducePerRoundOperator;
import org.apache.flink.testutils.junit.SharedObjects;
import org.apache.flink.testutils.junit.SharedReference;
@@ -123,7 +124,18 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger {
.setParallelism(1);
return new IterationBodyResult(
- DataStreamList.of(reducer.filter(x -> x < maxRound)),
+ DataStreamList.of(
+ reducer.partitionCustom(
+ (k, numPartitions) -> k % numPartitions,
+ x -> x)
+ .map(x -> x)
+ .keyBy(x -> x)
+ .process(
+ new StatefulProcessFunction<
+ Integer>() {})
+ .setParallelism(4)
+ .filter(x -> x < maxRound)
+ .setParallelism(1)),
DataStreamList.of(
reducer.getSideOutput(
TwoInputReducePerRoundOperator.OUTPUT_TAG)),
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
index 6d80f23..f3f2272 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
@@ -38,6 +38,7 @@ import org.apache.flink.test.iteration.operators.IncrementEpochMap;
import org.apache.flink.test.iteration.operators.OutputRecord;
import org.apache.flink.test.iteration.operators.ReduceAllRoundProcessFunction;
import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.StatefulProcessFunction;
import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction;
import org.apache.flink.testutils.junit.SharedObjects;
import org.apache.flink.testutils.junit.SharedReference;
@@ -192,7 +193,16 @@ public class UnboundedStreamIterationITCase extends TestLogger {
new ReduceAllRoundProcessFunction(sync, maxRound));
return new IterationBodyResult(
DataStreamList.of(
- reducer.map(new IncrementEpochMap())
+ reducer.partitionCustom(
+ (k, numPartitions) -> k % numPartitions,
+ EpochRecord::getValue)
+ .map(x -> x)
+ .keyBy(EpochRecord::getValue)
+ .process(
+ new StatefulProcessFunction<
+ EpochRecord>() {})
+ .setParallelism(4)
+ .map(new IncrementEpochMap())
.setParallelism(numSources)),
DataStreamList.of(
reducer.getSideOutput(
@@ -234,10 +244,20 @@ public class UnboundedStreamIterationITCase extends TestLogger {
.process(
new TwoInputReduceAllRoundProcessFunction(
sync, maxRound));
+
+ SingleOutputStreamOperator<EpochRecord> feedbackStream =
+ reducer.partitionCustom(
+ (k, numPartitions) -> k % numPartitions,
+ EpochRecord::getValue)
+ .map(x -> x)
+ .keyBy(EpochRecord::getValue)
+ .process(new StatefulProcessFunction<EpochRecord>() {})
+ .setParallelism(4)
+ .map(new IncrementEpochMap())
+ .setParallelism(numSources);
+
return new IterationBodyResult(
- DataStreamList.of(
- reducer.map(new IncrementEpochMap())
- .setParallelism(numSources)),
+ DataStreamList.of(feedbackStream),
DataStreamList.of(
reducer.getSideOutput(
new OutputTag<OutputRecord<Integer>>(
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java
new file mode 100644
index 0000000..47415f5
--- /dev/null
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration.operators;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.util.Collector;
+
+/**
+ * This is a function that uses keyed state so that we could verify the correctness of using keyed
+ * stream inside the iteration.
+ */
+public class StatefulProcessFunction<T> extends KeyedProcessFunction<Integer, T, T> {
+
+ private ValueState<Integer> state;
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+ this.state =
+ getRuntimeContext().getState(new ValueStateDescriptor<>("state", Integer.class));
+ }
+
+ @Override
+ public void processElement(T value, Context ctx, Collector<T> out) throws Exception {
+ if (state.value() == null) {
+ state.update(0);
+
+ // Trying registers a timer
+ ctx.timerService().registerEventTimeTimer(1000L);
+ } else {
+ state.update(state.value() + 1);
+ }
+
+ out.collect(value);
+ }
+}