You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text copy.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/16 05:29:10 UTC
[flink-ml] branch master updated: [FLINK-22915][FLIP-173] Updates the static load(...) method of Stage subclasses to take StreamExecutionEnvironment as parameter
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 01950cb [FLINK-22915][FLIP-173] Updates the static load(...) method of Stage subclasses to take StreamExecutionEnvironment as parameter
01950cb is described below
commit 01950cb967d94c6aceac7f080085898361a07a08
Author: Dong Lin <li...@gmail.com>
AuthorDate: Sun Nov 14 20:54:54 2021 +0800
[FLINK-22915][FLIP-173] Updates the static load(...) method of Stage subclasses to take StreamExecutionEnvironment as parameter
---
flink-ml-api/pom.xml | 7 +++++++
.../java/org/apache/flink/ml/api/core/Pipeline.java | 5 +++--
.../org/apache/flink/ml/api/core/PipelineModel.java | 7 +++++--
.../main/java/org/apache/flink/ml/api/core/Stage.java | 9 +++++----
.../java/org/apache/flink/ml/util/ReadWriteUtils.java | 15 ++++++++++-----
.../org/apache/flink/ml/api/core/ExampleStages.java | 8 ++++----
.../java/org/apache/flink/ml/api/core/PipelineTest.java | 4 ++--
.../java/org/apache/flink/ml/api/core/StageTest.java | 17 +++++++++++------
8 files changed, 47 insertions(+), 25 deletions(-)
diff --git a/flink-ml-api/pom.xml b/flink-ml-api/pom.xml
index ddfc659..b828457 100644
--- a/flink-ml-api/pom.xml
+++ b/flink-ml-api/pom.xml
@@ -41,6 +41,13 @@ under the License.
<dependency>
<groupId>org.apache.flink</groupId>
+ <artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
+ <version>${flink.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
<artifactId>flink-table-planner_${scala.binary.version}</artifactId>
<version>${flink.version}</version>
<scope>test</scope>
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
index f1e5d0c..94df9f8 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
@@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import java.io.IOException;
@@ -111,8 +112,8 @@ public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
ReadWriteUtils.savePipeline(this, stages, path);
}
- public static Pipeline load(String path) throws IOException {
- return new Pipeline(ReadWriteUtils.loadPipeline(path, Pipeline.class.getName()));
+ public static Pipeline load(StreamExecutionEnvironment env, String path) throws IOException {
+ return new Pipeline(ReadWriteUtils.loadPipeline(env, path, Pipeline.class.getName()));
}
/**
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
index 45bb757..cf11beb 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
@@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import java.io.IOException;
@@ -72,8 +73,10 @@ public final class PipelineModel implements Model<PipelineModel> {
ReadWriteUtils.savePipeline(this, stages, path);
}
- public static PipelineModel load(String path) throws IOException {
- return new PipelineModel(ReadWriteUtils.loadPipeline(path, PipelineModel.class.getName()));
+ public static PipelineModel load(StreamExecutionEnvironment env, String path)
+ throws IOException {
+ return new PipelineModel(
+ ReadWriteUtils.loadPipeline(env, path, PipelineModel.class.getName()));
}
/**
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
index 168599b..fa12186 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
@@ -31,14 +31,15 @@ import java.io.Serializable;
*
* <p>Each stage is with parameters, and requires a public empty constructor for restoration.
*
+ * <p>NOTE: every Stage subclass should implement a static method with signature {@code static T
+ * load(StreamExecutionEnvironment env, String path)}, where {@code T} refers to the concrete
+ * subclass. This static method should instantiate a new stage instance based on the data read from
+ * the given path.
+ *
* @param <T> The class type of the Stage implementation itself.
*/
@PublicEvolving
public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable {
/** Saves this stage to the given path. */
void save(String path) throws IOException;
-
- // NOTE: every Stage subclass should implement a static method with signature "static T
- // load(String path)", where T refers to the concrete subclass. This static method should
- // instantiate a new stage instance based on the data from the given path.
}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index 283c1e5..68378cd 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -20,6 +20,7 @@ package org.apache.flink.ml.util;
import org.apache.flink.ml.api.core.Stage;
import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
@@ -183,11 +184,13 @@ public class ReadWriteUtils {
* <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not
* match the className of the previously saved Pipeline or PipelineModel.
*
+ * @param env A StreamExecutionEnvironment instance.
* @param path The parent directory to load the pipeline metadata and its stages.
* @param expectedClassName The expected class name of the pipeline.
* @return A list of stages.
*/
- public static List<Stage<?>> loadPipeline(String path, String expectedClassName)
+ public static List<Stage<?>> loadPipeline(
+ StreamExecutionEnvironment env, String path, String expectedClassName)
throws IOException {
Map<String, ?> metadata = loadMetadata(path, expectedClassName);
int numStages = (Integer) metadata.get("numStages");
@@ -195,7 +198,7 @@ public class ReadWriteUtils {
for (int i = 0; i < numStages; i++) {
String stagePath = getPathForPipelineStage(i, numStages, path);
- stages.add(loadStage(stagePath));
+ stages.add(loadStage(env, stagePath));
}
return stages;
}
@@ -253,18 +256,20 @@ public class ReadWriteUtils {
*
* <p>Required: the stage class must have a static load() method.
*
+ * @param env A StreamExecutionEnvironment instance.
* @param path The parent directory of the stage metadata file.
* @return An instance of Stage.
*/
- public static Stage<?> loadStage(String path) throws IOException {
+ public static Stage<?> loadStage(StreamExecutionEnvironment env, String path)
+ throws IOException {
Map<String, ?> metadata = loadMetadata(path, "");
String className = (String) metadata.get("className");
try {
Class<?> clazz = Class.forName(className);
- Method method = clazz.getMethod("load", String.class);
+ Method method = clazz.getMethod("load", StreamExecutionEnvironment.class, String.class);
method.setAccessible(true);
- return (Stage<?>) method.invoke(null, path);
+ return (Stage<?>) method.invoke(null, env, path);
} catch (NoSuchMethodException e) {
String methodName = String.format("%s::load(String)", className);
throw new RuntimeException(
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
index ba04006..14b7d50 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
@@ -124,13 +124,12 @@ public class ExampleStages {
}
}
- public static SumModel load(String path) throws IOException {
+ public static SumModel load(StreamExecutionEnvironment env, String path)
+ throws IOException {
SumModel model = ReadWriteUtils.loadStageParam(path);
File dataFile = Paths.get(path, "data", "delta").toFile();
try (DataInputStream inputStream = new DataInputStream(new FileInputStream(dataFile))) {
- StreamExecutionEnvironment env =
- StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
Table modelData = tEnv.fromDataStream(env.fromElements(inputStream.readInt()));
return model.setModelData(modelData);
@@ -220,7 +219,8 @@ public class ExampleStages {
ReadWriteUtils.saveMetadata(this, path);
}
- public static SumEstimator load(String path) throws IOException {
+ public static SumEstimator load(StreamExecutionEnvironment env, String path)
+ throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
}
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
index b23bd5b..8792455 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
@@ -83,7 +83,7 @@ public class PipelineTest extends AbstractTestBase {
Path tempDir = Files.createTempDirectory("PipelineTest");
String path = Paths.get(tempDir.toString(), "testPipelineModelSaveLoad").toString();
model.save(path);
- Model<?> loadedModel = PipelineModel.load(path);
+ Model<?> loadedModel = PipelineModel.load(env, path);
// Executes the loaded PipelineModel and verifies that it produces the expected output.
executeAndCheckOutput(loadedModel, Arrays.asList(1, 2, 3), Arrays.asList(61, 62, 63));
@@ -107,7 +107,7 @@ public class PipelineTest extends AbstractTestBase {
Path tempDir = Files.createTempDirectory("PipelineTest");
String path = Paths.get(tempDir.toString(), "testPipeline").toString();
estimator.save(path);
- Estimator<?, ?> loadedEstimator = Pipeline.load(path);
+ Estimator<?, ?> loadedEstimator = Pipeline.load(env, path);
// Executes the loaded Pipeline and verifies that it produces the expected output.
executeAndCheckOutput(loadedEstimator, Arrays.asList(1, 2, 3), Arrays.asList(77, 78, 79));
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
index 88b48b1..51de9d8 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
@@ -35,6 +35,7 @@ import org.apache.flink.ml.param.StringParam;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.junit.Assert;
import org.junit.Test;
@@ -121,7 +122,7 @@ public class StageTest {
ReadWriteUtils.saveMetadata(this, path);
}
- public static MyStage load(String path) throws IOException {
+ public static MyStage load(StreamExecutionEnvironment env, String path) throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
}
@@ -167,7 +168,8 @@ public class StageTest {
// Saves and loads the given stage. And verifies that the loaded stage has same parameter values
// as the original stage.
private static Stage<?> validateStageSaveLoad(
- Stage<?> stage, Map<String, Object> paramOverrides) throws IOException {
+ StreamExecutionEnvironment env, Stage<?> stage, Map<String, Object> paramOverrides)
+ throws IOException {
for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
Param<?> param = stage.getParam(entry.getKey());
ReadWriteUtils.setStageParam(stage, param, entry.getValue());
@@ -183,7 +185,7 @@ public class StageTest {
// This is expected.
}
- Stage<?> loadedStage = ReadWriteUtils.loadStage(path);
+ Stage<?> loadedStage = ReadWriteUtils.loadStage(env, path);
for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
Param<?> param = loadedStage.getParam(entry.getKey());
Assert.assertEquals(entry.getValue(), loadedStage.get(param));
@@ -307,26 +309,29 @@ public class StageTest {
@Test
public void testStageSaveLoad() throws IOException {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
MyStage stage = new MyStage();
stage.set(stage.paramWithNullDefault, 1);
- Stage<?> loadedStage = validateStageSaveLoad(stage, Collections.emptyMap());
+ Stage<?> loadedStage = validateStageSaveLoad(env, stage, Collections.emptyMap());
Assert.assertEquals(1, (int) loadedStage.get(MyParams.INT_PARAM));
}
@Test
public void testStageSaveLoadWithParamOverrides() throws IOException {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
MyStage stage = new MyStage();
stage.set(stage.paramWithNullDefault, 1);
Stage<?> loadedStage =
- validateStageSaveLoad(stage, Collections.singletonMap("intParam", 10));
+ validateStageSaveLoad(env, stage, Collections.singletonMap("intParam", 10));
Assert.assertEquals(10, (int) loadedStage.get(MyParams.INT_PARAM));
}
@Test
public void testStageLoadWithoutLoadMethod() throws IOException {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
MyStageWithoutLoad stage = new MyStageWithoutLoad();
try {
- validateStageSaveLoad(stage, Collections.emptyMap());
+ validateStageSaveLoad(env, stage, Collections.emptyMap());
Assert.fail("Expected RuntimeException");
} catch (RuntimeException e) {
Assert.assertTrue(e.getMessage().contains("not implemented"));