You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/15 06:38:16 UTC

[flink-ml] 02/03: [hotfix][iteration] Return more fine-grained operator class for the WrapperFactory

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 5c69eeea18f5302f780512bb1e5fb3d898172989
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Nov 15 11:32:43 2021 +0800

    [hotfix][iteration] Return more fine-grained operator class for the WrapperFactory
---
 .../flink/iteration/operator/OperatorWrapper.java       |  3 +++
 .../iteration/operator/WrapperOperatorFactory.java      |  2 +-
 .../operator/allround/AllRoundOperatorWrapper.java      | 17 +++++++++++++++++
 .../operator/perround/PerRoundOperatorWrapper.java      | 17 +++++++++++++++++
 .../DraftExecutionEnvironmentSwitchWrapperTest.java     | 12 ++++++++++++
 .../ml/common/broadcast/operator/BroadcastWrapper.java  | 15 +++++++++++++++
 6 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
index 1d7bbbe..e7fd379 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
@@ -35,6 +35,22 @@ public interface OperatorWrapper<T, R> extends Serializable {
             StreamOperatorParameters<R> operatorParameters,
             StreamOperatorFactory<T> operatorFactory);
 
+    /**
+     * Returns the class of the wrapper operator that this wrapper builds around the operator
+     * created by {@code operatorFactory}. {@link WrapperOperatorFactory} reports this class as
+     * its stream operator class, so implementations should return the most specific class they
+     * can, e.g. by distinguishing one-input, two-input and multiple-input operators.
+     *
+     * @param classLoader the class loader to use when resolving the wrapped operator's class. It
+     *     should be passed on to {@code operatorFactory} instead of another class loader.
+     * @param operatorFactory the factory of the operator being wrapped.
+     * @return the class of the wrapper operator.
+     * @throws UnsupportedOperationException if the wrapper does not support the kind of operator
+     *     created by {@code operatorFactory}.
+     */
+    Class<? extends StreamOperator> getStreamOperatorClass(
+            ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory);
+
     <KEY> KeySelector<R, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector);
 
     StreamPartitioner<R> wrapStreamPartitioner(StreamPartitioner<T> streamPartitioner);
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
index dfb325a..c6a9a06 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
@@ -48,7 +48,7 @@ public class WrapperOperatorFactory<OUT>
 
     @Override
     public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
-        return AbstractWrapperOperator.class;
+        return wrapper.getStreamOperatorClass(classLoader, operatorFactory); // chosen by wrapper
     }
 
     @VisibleForTesting
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
index c28acd1..0cf54c0 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
@@ -57,6 +57,25 @@ public class AllRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationR
     }
 
     @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(
+            ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+        // Resolve the wrapped operator's class with the class loader supplied by the caller; it
+        // may differ from the class loader that loaded this wrapper class.
+        Class<? extends StreamOperator> operatorClass =
+                operatorFactory.getStreamOperatorClass(classLoader);
+        if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return OneInputAllRoundWrapperOperator.class;
+        } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return TwoInputAllRoundWrapperOperator.class;
+        } else if (MultipleInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return MultipleInputAllRoundWrapperOperator.class;
+        } else {
+            throw new UnsupportedOperationException(
+                    "Unsupported operator class for all-round wrapper: " + operatorClass);
+        }
+    }
+
+    @Override
     public <KEY> KeySelector<IterationRecord<T>, KEY> wrapKeySelector(
             KeySelector<T, KEY> keySelector) {
         return new ProxyKeySelector<>(keySelector);
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
index ffa2221..87ee6aa 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
@@ -57,6 +57,25 @@ public class PerRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationR
     }
 
     @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(
+            ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+        // Resolve the wrapped operator's class with the class loader supplied by the caller; it
+        // may differ from the class loader that loaded this wrapper class.
+        Class<? extends StreamOperator> operatorClass =
+                operatorFactory.getStreamOperatorClass(classLoader);
+        if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return OneInputPerRoundWrapperOperator.class;
+        } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return TwoInputPerRoundWrapperOperator.class;
+        } else if (MultipleInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return MultipleInputPerRoundWrapperOperator.class;
+        } else {
+            throw new UnsupportedOperationException(
+                    "Unsupported operator class for per-round wrapper: " + operatorClass);
+        }
+    }
+
+    @Override
     public <KEY> KeySelector<IterationRecord<T>, KEY> wrapKeySelector(
             KeySelector<T, KEY> keySelector) {
         return new ProxyKeySelector<>(keySelector);
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
index 86df0fc..b544003 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
@@ -113,6 +113,12 @@ public class DraftExecutionEnvironmentSwitchWrapperTest extends TestLogger {
         }
 
         @Override
+        public Class<? extends StreamOperator> getStreamOperatorClass(
+                ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+            return StreamMap.class; // constant; operatorFactory is not inspected
+        }
+
+        @Override
         public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
             return keySelector;
         }
@@ -142,6 +148,12 @@ public class DraftExecutionEnvironmentSwitchWrapperTest extends TestLogger {
         }
 
         @Override
+        public Class<? extends StreamOperator> getStreamOperatorClass(
+                ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+            return StreamFilter.class; // constant; operatorFactory is not inspected
+        }
+
+        @Override
         public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
             return keySelector;
         }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
index 2e3f88d..2a18c85 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
@@ -75,6 +75,23 @@ public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
     }
 
     @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(
+            ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+        // Resolve the wrapped operator's class with the class loader supplied by the caller; it
+        // may differ from the class loader that loaded this wrapper class.
+        Class<? extends StreamOperator> operatorClass =
+                operatorFactory.getStreamOperatorClass(classLoader);
+        if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return OneInputBroadcastWrapperOperator.class;
+        } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return TwoInputBroadcastWrapperOperator.class;
+        } else {
+            throw new UnsupportedOperationException(
+                    "Unsupported operator class for with-broadcast wrapper: " + operatorClass);
+        }
+    }
+
+    @Override
     public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
         return keySelector;
     }