You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/16 05:29:10 UTC

[flink-ml] branch master updated: [FLINK-22915][FLIP-173] Updates the static load(...) method of Stage subclasses to take StreamExecutionEnvironment as parameter

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 01950cb  [FLINK-22915][FLIP-173] Updates the static load(...) method of Stage subclasses to take StreamExecutionEnvironment as parameter
01950cb is described below

commit 01950cb967d94c6aceac7f080085898361a07a08
Author: Dong Lin <li...@gmail.com>
AuthorDate: Sun Nov 14 20:54:54 2021 +0800

    [FLINK-22915][FLIP-173] Updates the static load(...) method of Stage subclasses to take StreamExecutionEnvironment as parameter
---
 flink-ml-api/pom.xml                                    |  7 +++++++
 .../java/org/apache/flink/ml/api/core/Pipeline.java     |  5 +++--
 .../org/apache/flink/ml/api/core/PipelineModel.java     |  7 +++++--
 .../main/java/org/apache/flink/ml/api/core/Stage.java   |  9 +++++----
 .../java/org/apache/flink/ml/util/ReadWriteUtils.java   | 15 ++++++++++-----
 .../org/apache/flink/ml/api/core/ExampleStages.java     |  8 ++++----
 .../java/org/apache/flink/ml/api/core/PipelineTest.java |  4 ++--
 .../java/org/apache/flink/ml/api/core/StageTest.java    | 17 +++++++++++------
 8 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/flink-ml-api/pom.xml b/flink-ml-api/pom.xml
index ddfc659..b828457 100644
--- a/flink-ml-api/pom.xml
+++ b/flink-ml-api/pom.xml
@@ -41,6 +41,13 @@ under the License.
 
     <dependency>
       <groupId>org.apache.flink</groupId>
+      <artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>provided</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.flink</groupId>
       <artifactId>flink-table-planner_${scala.binary.version}</artifactId>
       <version>${flink.version}</version>
       <scope>test</scope>
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
index f1e5d0c..94df9f8 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
@@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 
 import java.io.IOException;
@@ -111,8 +112,8 @@ public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
         ReadWriteUtils.savePipeline(this, stages, path);
     }
 
-    public static Pipeline load(String path) throws IOException {
-        return new Pipeline(ReadWriteUtils.loadPipeline(path, Pipeline.class.getName()));
+    public static Pipeline load(StreamExecutionEnvironment env, String path) throws IOException {
+        return new Pipeline(ReadWriteUtils.loadPipeline(env, path, Pipeline.class.getName()));
     }
 
     /**
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
index 45bb757..cf11beb 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
@@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 
 import java.io.IOException;
@@ -72,8 +73,10 @@ public final class PipelineModel implements Model<PipelineModel> {
         ReadWriteUtils.savePipeline(this, stages, path);
     }
 
-    public static PipelineModel load(String path) throws IOException {
-        return new PipelineModel(ReadWriteUtils.loadPipeline(path, PipelineModel.class.getName()));
+    public static PipelineModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return new PipelineModel(
+                ReadWriteUtils.loadPipeline(env, path, PipelineModel.class.getName()));
     }
 
     /**
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
index 168599b..fa12186 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
@@ -31,14 +31,15 @@ import java.io.Serializable;
  *
  * <p>Each stage is with parameters, and requires a public empty constructor for restoration.
  *
+ * <p>NOTE: every Stage subclass should implement a static method with signature {@code static T
+ * load(StreamExecutionEnvironment env, String path)}, where {@code T} refers to the concrete
+ * subclass. This static method should instantiate a new stage instance based on the data read from
+ * the given path.
+ *
  * @param <T> The class type of the Stage implementation itself.
  */
 @PublicEvolving
 public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable {
     /** Saves this stage to the given path. */
     void save(String path) throws IOException;
-
-    // NOTE: every Stage subclass should implement a static method with signature "static T
-    // load(String path)", where T refers to the concrete subclass. This static method should
-    // instantiate a new stage instance based on the data from the given path.
 }
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index 283c1e5..68378cd 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -20,6 +20,7 @@ package org.apache.flink.ml.util;
 
 import org.apache.flink.ml.api.core.Stage;
 import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.util.InstantiationUtil;
 
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
@@ -183,11 +184,13 @@ public class ReadWriteUtils {
      * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not
      * match the className of the previously saved Pipeline or PipelineModel.
      *
+     * @param env A StreamExecutionEnvironment instance.
      * @param path The parent directory to load the pipeline metadata and its stages.
      * @param expectedClassName The expected class name of the pipeline.
      * @return A list of stages.
      */
-    public static List<Stage<?>> loadPipeline(String path, String expectedClassName)
+    public static List<Stage<?>> loadPipeline(
+            StreamExecutionEnvironment env, String path, String expectedClassName)
             throws IOException {
         Map<String, ?> metadata = loadMetadata(path, expectedClassName);
         int numStages = (Integer) metadata.get("numStages");
@@ -195,7 +198,7 @@ public class ReadWriteUtils {
 
         for (int i = 0; i < numStages; i++) {
             String stagePath = getPathForPipelineStage(i, numStages, path);
-            stages.add(loadStage(stagePath));
+            stages.add(loadStage(env, stagePath));
         }
         return stages;
     }
@@ -253,18 +256,20 @@ public class ReadWriteUtils {
      *
      * <p>Required: the stage class must have a static load() method.
      *
+     * @param env A StreamExecutionEnvironment instance.
      * @param path The parent directory of the stage metadata file.
      * @return An instance of Stage.
      */
-    public static Stage<?> loadStage(String path) throws IOException {
+    public static Stage<?> loadStage(StreamExecutionEnvironment env, String path)
+            throws IOException {
         Map<String, ?> metadata = loadMetadata(path, "");
         String className = (String) metadata.get("className");
 
         try {
             Class<?> clazz = Class.forName(className);
-            Method method = clazz.getMethod("load", String.class);
+            Method method = clazz.getMethod("load", StreamExecutionEnvironment.class, String.class);
             method.setAccessible(true);
-            return (Stage<?>) method.invoke(null, path);
+            return (Stage<?>) method.invoke(null, env, path);
         } catch (NoSuchMethodException e) {
             String methodName = String.format("%s::load(String)", className);
             throw new RuntimeException(
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
index ba04006..14b7d50 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
@@ -124,13 +124,12 @@ public class ExampleStages {
             }
         }
 
-        public static SumModel load(String path) throws IOException {
+        public static SumModel load(StreamExecutionEnvironment env, String path)
+                throws IOException {
             SumModel model = ReadWriteUtils.loadStageParam(path);
             File dataFile = Paths.get(path, "data", "delta").toFile();
 
             try (DataInputStream inputStream = new DataInputStream(new FileInputStream(dataFile))) {
-                StreamExecutionEnvironment env =
-                        StreamExecutionEnvironment.getExecutionEnvironment();
                 StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
                 Table modelData = tEnv.fromDataStream(env.fromElements(inputStream.readInt()));
                 return model.setModelData(modelData);
@@ -220,7 +219,8 @@ public class ExampleStages {
             ReadWriteUtils.saveMetadata(this, path);
         }
 
-        public static SumEstimator load(String path) throws IOException {
+        public static SumEstimator load(StreamExecutionEnvironment env, String path)
+                throws IOException {
             return ReadWriteUtils.loadStageParam(path);
         }
     }
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
index b23bd5b..8792455 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
@@ -83,7 +83,7 @@ public class PipelineTest extends AbstractTestBase {
         Path tempDir = Files.createTempDirectory("PipelineTest");
         String path = Paths.get(tempDir.toString(), "testPipelineModelSaveLoad").toString();
         model.save(path);
-        Model<?> loadedModel = PipelineModel.load(path);
+        Model<?> loadedModel = PipelineModel.load(env, path);
 
         // Executes the loaded PipelineModel and verifies that it produces the expected output.
         executeAndCheckOutput(loadedModel, Arrays.asList(1, 2, 3), Arrays.asList(61, 62, 63));
@@ -107,7 +107,7 @@ public class PipelineTest extends AbstractTestBase {
         Path tempDir = Files.createTempDirectory("PipelineTest");
         String path = Paths.get(tempDir.toString(), "testPipeline").toString();
         estimator.save(path);
-        Estimator<?, ?> loadedEstimator = Pipeline.load(path);
+        Estimator<?, ?> loadedEstimator = Pipeline.load(env, path);
 
         // Executes the loaded Pipeline and verifies that it produces the expected output.
         executeAndCheckOutput(loadedEstimator, Arrays.asList(1, 2, 3), Arrays.asList(77, 78, 79));
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
index 88b48b1..51de9d8 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
@@ -35,6 +35,7 @@ import org.apache.flink.ml.param.StringParam;
 import org.apache.flink.ml.param.WithParams;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 
 import org.junit.Assert;
 import org.junit.Test;
@@ -121,7 +122,7 @@ public class StageTest {
             ReadWriteUtils.saveMetadata(this, path);
         }
 
-        public static MyStage load(String path) throws IOException {
+        public static MyStage load(StreamExecutionEnvironment env, String path) throws IOException {
             return ReadWriteUtils.loadStageParam(path);
         }
     }
@@ -167,7 +168,8 @@ public class StageTest {
     // Saves and loads the given stage. And verifies that the loaded stage has same parameter values
     // as the original stage.
     private static Stage<?> validateStageSaveLoad(
-            Stage<?> stage, Map<String, Object> paramOverrides) throws IOException {
+            StreamExecutionEnvironment env, Stage<?> stage, Map<String, Object> paramOverrides)
+            throws IOException {
         for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
             Param<?> param = stage.getParam(entry.getKey());
             ReadWriteUtils.setStageParam(stage, param, entry.getValue());
@@ -183,7 +185,7 @@ public class StageTest {
             // This is expected.
         }
 
-        Stage<?> loadedStage = ReadWriteUtils.loadStage(path);
+        Stage<?> loadedStage = ReadWriteUtils.loadStage(env, path);
         for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
             Param<?> param = loadedStage.getParam(entry.getKey());
             Assert.assertEquals(entry.getValue(), loadedStage.get(param));
@@ -307,26 +309,29 @@ public class StageTest {
 
     @Test
     public void testStageSaveLoad() throws IOException {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         MyStage stage = new MyStage();
         stage.set(stage.paramWithNullDefault, 1);
-        Stage<?> loadedStage = validateStageSaveLoad(stage, Collections.emptyMap());
+        Stage<?> loadedStage = validateStageSaveLoad(env, stage, Collections.emptyMap());
         Assert.assertEquals(1, (int) loadedStage.get(MyParams.INT_PARAM));
     }
 
     @Test
     public void testStageSaveLoadWithParamOverrides() throws IOException {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         MyStage stage = new MyStage();
         stage.set(stage.paramWithNullDefault, 1);
         Stage<?> loadedStage =
-                validateStageSaveLoad(stage, Collections.singletonMap("intParam", 10));
+                validateStageSaveLoad(env, stage, Collections.singletonMap("intParam", 10));
         Assert.assertEquals(10, (int) loadedStage.get(MyParams.INT_PARAM));
     }
 
     @Test
     public void testStageLoadWithoutLoadMethod() throws IOException {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         MyStageWithoutLoad stage = new MyStageWithoutLoad();
         try {
-            validateStageSaveLoad(stage, Collections.emptyMap());
+            validateStageSaveLoad(env, stage, Collections.emptyMap());
             Assert.fail("Expected RuntimeException");
         } catch (RuntimeException e) {
             Assert.assertTrue(e.getMessage().contains("not implemented"));