You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/15 06:38:16 UTC
[flink-ml] 02/03: [hotfix][iteration] Return more fine-grained operator class for the WrapperFactory
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 5c69eeea18f5302f780512bb1e5fb3d898172989
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Nov 15 11:32:43 2021 +0800
[hotfix][iteration] Return more fine-grained operator class for the WrapperFactory
---
.../flink/iteration/operator/OperatorWrapper.java | 3 +++
.../iteration/operator/WrapperOperatorFactory.java | 2 +-
.../operator/allround/AllRoundOperatorWrapper.java | 17 +++++++++++++++++
.../operator/perround/PerRoundOperatorWrapper.java | 17 +++++++++++++++++
.../DraftExecutionEnvironmentSwitchWrapperTest.java | 12 ++++++++++++
.../ml/common/broadcast/operator/BroadcastWrapper.java | 15 +++++++++++++++
6 files changed, 65 insertions(+), 1 deletion(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
index 1d7bbbe..e7fd379 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
@@ -35,6 +35,9 @@ public interface OperatorWrapper<T, R> extends Serializable {
StreamOperatorParameters<R> operatorParameters,
StreamOperatorFactory<T> operatorFactory);
+ Class<? extends StreamOperator> getStreamOperatorClass(
+ ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory); // returns the stream operator class this wrapper reports for operatorFactory; WrapperOperatorFactory#getStreamOperatorClass delegates to this method
+
<KEY> KeySelector<R, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector);
StreamPartitioner<R> wrapStreamPartitioner(StreamPartitioner<T> streamPartitioner);
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
index dfb325a..c6a9a06 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
@@ -48,7 +48,7 @@ public class WrapperOperatorFactory<OUT>
@Override
public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
- return AbstractWrapperOperator.class;
+ return wrapper.getStreamOperatorClass(classLoader, operatorFactory); // the wrapper now decides which operator class is reported, instead of always reporting AbstractWrapperOperator
}
@VisibleForTesting
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
index c28acd1..0cf54c0 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
@@ -57,6 +57,23 @@ public class AllRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationR
}
@Override
+ public Class<? extends StreamOperator> getStreamOperatorClass(
+ ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+ Class<? extends StreamOperator> operatorClass = // class of the operator being wrapped; its input arity selects the wrapper class returned below
+ operatorFactory.getStreamOperatorClass(classLoader); // honour the caller-supplied loader instead of substituting this class's own loader
+ if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return OneInputAllRoundWrapperOperator.class;
+ } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return TwoInputAllRoundWrapperOperator.class;
+ } else if (MultipleInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return MultipleInputAllRoundWrapperOperator.class;
+ } else {
+ throw new UnsupportedOperationException( // wrapped class is none of one-, two- or multiple-input
+ "Unsupported operator class for all-round wrapper: " + operatorClass);
+ }
+ }
+
+ @Override
public <KEY> KeySelector<IterationRecord<T>, KEY> wrapKeySelector(
KeySelector<T, KEY> keySelector) {
return new ProxyKeySelector<>(keySelector);
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
index ffa2221..87ee6aa 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
@@ -57,6 +57,23 @@ public class PerRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationR
}
@Override
+ public Class<? extends StreamOperator> getStreamOperatorClass(
+ ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+ Class<? extends StreamOperator> operatorClass = // class of the operator being wrapped; its input arity selects the wrapper class returned below
+ operatorFactory.getStreamOperatorClass(classLoader); // honour the caller-supplied loader instead of substituting this class's own loader
+ if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return OneInputPerRoundWrapperOperator.class;
+ } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return TwoInputPerRoundWrapperOperator.class;
+ } else if (MultipleInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return MultipleInputPerRoundWrapperOperator.class;
+ } else {
+ throw new UnsupportedOperationException( // wrapped class is none of one-, two- or multiple-input
+ "Unsupported operator class for per-round wrapper: " + operatorClass);
+ }
+ }
+
+ @Override
public <KEY> KeySelector<IterationRecord<T>, KEY> wrapKeySelector(
KeySelector<T, KEY> keySelector) {
return new ProxyKeySelector<>(keySelector);
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
index 86df0fc..b544003 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
@@ -113,6 +113,12 @@ public class DraftExecutionEnvironmentSwitchWrapperTest extends TestLogger {
}
@Override
+ public Class<? extends StreamOperator> getStreamOperatorClass(
+ ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+ return StreamMap.class; // test stub: fixed class, both arguments are ignored
+ }
+
+ @Override
public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
return keySelector;
}
@@ -142,6 +148,12 @@ public class DraftExecutionEnvironmentSwitchWrapperTest extends TestLogger {
}
@Override
+ public Class<? extends StreamOperator> getStreamOperatorClass(
+ ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+ return StreamFilter.class; // test stub: fixed class, both arguments are ignored
+ }
+
+ @Override
public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
return keySelector;
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
index 2e3f88d..2a18c85 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
@@ -75,6 +75,21 @@ public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
}
@Override
+ public Class<? extends StreamOperator> getStreamOperatorClass(
+ ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+ Class<? extends StreamOperator> operatorClass = // class of the operator being wrapped; its input arity selects the wrapper class returned below
+ operatorFactory.getStreamOperatorClass(classLoader); // honour the caller-supplied loader instead of substituting this class's own loader
+ if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return OneInputBroadcastWrapperOperator.class;
+ } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return TwoInputBroadcastWrapperOperator.class;
+ } else {
+ throw new UnsupportedOperationException( // only one- and two-input operators are handled by this wrapper
+ "Unsupported operator class for with-broadcast wrapper: " + operatorClass);
+ }
+ }
+
+ @Override
public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
return keySelector;
}