You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/16 03:27:00 UTC

[flink-ml] branch master updated: [FLINK-22915][FLIP-173] Updates Model::setModelData(...) to return the Model instance itself

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 030542a  [FLINK-22915][FLIP-173] Updates Model::setModelData(...) to return the Model instance itself
030542a is described below

commit 030542a9d5db064edfb4fe775c0bee6d61699f52
Author: Dong Lin <li...@gmail.com>
AuthorDate: Mon Nov 15 16:12:57 2021 +0800

    [FLINK-22915][FLIP-173] Updates Model::setModelData(...) to return the Model instance itself
    
    This closes #33.
---
 .../src/main/java/org/apache/flink/ml/api/core/Model.java |  2 +-
 .../java/org/apache/flink/ml/api/core/ExampleStages.java  | 10 ++++------
 .../java/org/apache/flink/ml/api/core/PipelineTest.java   | 15 +++++----------
 3 files changed, 10 insertions(+), 17 deletions(-)

diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
index 8caffe3..e24664b 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
@@ -35,7 +35,7 @@ public interface Model<T extends Model<T>> extends Transformer<T> {
      *
      * @param inputs a list of tables
      */
-    default void setModelData(Table... inputs) {
+    default T setModelData(Table... inputs) {
         throw new UnsupportedOperationException("this operation is not supported");
     }
 
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
index 2e4b4c2..ba04006 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
@@ -89,11 +89,12 @@ public class ExampleStages {
         }
 
         @Override
-        public void setModelData(Table... inputs) {
+        public SumModel setModelData(Table... inputs) {
             StreamTableEnvironment tEnv =
                     (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
 
             modelData = tEnv.toDataStream(inputs[0], Integer.class);
+            return this;
         }
 
         @Override
@@ -132,8 +133,7 @@ public class ExampleStages {
                         StreamExecutionEnvironment.getExecutionEnvironment();
                 StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
                 Table modelData = tEnv.fromDataStream(env.fromElements(inputStream.readInt()));
-                model.setModelData(modelData);
-                return model;
+                return model.setModelData(modelData);
             }
         }
     }
@@ -209,9 +209,7 @@ public class ExampleStages {
                             .setParallelism(1);
             try {
                 SumModel model = new SumModel();
-                model.setModelData(tEnv.fromDataStream(modelData));
-
-                return model;
+                return model.setModelData(tEnv.fromDataStream(modelData));
             } catch (Exception e) {
                 throw new RuntimeException(e);
             }
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
index 74f9c65..b23bd5b 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
@@ -69,12 +69,9 @@ public class PipelineTest extends AbstractTestBase {
         StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
         // Builds a PipelineModel that increments input value by 60. This PipelineModel consists of
         // three stages where each stage increments input value by 10, 20, and 30 respectively.
-        SumModel modelA = new SumModel();
-        modelA.setModelData(tEnv.fromValues(10));
-        SumModel modelB = new SumModel();
-        modelB.setModelData(tEnv.fromValues(20));
-        SumModel modelC = new SumModel();
-        modelC.setModelData(tEnv.fromValues(30));
+        SumModel modelA = new SumModel().setModelData(tEnv.fromValues(10));
+        SumModel modelB = new SumModel().setModelData(tEnv.fromValues(20));
+        SumModel modelC = new SumModel().setModelData(tEnv.fromValues(30));
 
         List<Stage<?>> stages = Arrays.asList(modelA, modelB, modelC);
         Model<?> model = new PipelineModel(stages);
@@ -97,10 +94,8 @@ public class PipelineTest extends AbstractTestBase {
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
         // Builds a Pipeline that consists of a Model, an Estimator, and a model.
-        SumModel modelA = new SumModel();
-        modelA.setModelData(tEnv.fromValues(10));
-        SumModel modelB = new SumModel();
-        modelB.setModelData(tEnv.fromValues(30));
+        SumModel modelA = new SumModel().setModelData(tEnv.fromValues(10));
+        SumModel modelB = new SumModel().setModelData(tEnv.fromValues(30));
 
         List<Stage<?>> stages = Arrays.asList(modelA, new SumEstimator(), modelB);
         Estimator<?, ?> estimator = new Pipeline(stages);