You are viewing a plain text version of this content. The link to the canonical version is not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by jq...@apache.org on 2021/11/09 12:12:18 UTC

[flink-ml] 02/02: [FLINK-24354][FLIP-174] Improve the WithParams interface

This is an automated email from the ASF dual-hosted git repository.

jqin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 9c44eef25970338fe32dcf77ce45efac74c4324f
Author: Dong Lin <li...@gmail.com>
AuthorDate: Sun Sep 26 21:40:59 2021 +0800

    [FLINK-24354][FLIP-174] Improve the WithParams interface
---
 flink-ml-api/pom.xml                               |  15 +
 .../org/apache/flink/ml/api/core/Pipeline.java     |  23 +-
 .../apache/flink/ml/api/core/PipelineModel.java    |  23 +-
 .../java/org/apache/flink/ml/api/core/Stage.java   |   2 +-
 .../org/apache/flink/ml/param/BooleanParam.java    |  35 ++
 .../apache/flink/ml/param/DoubleArrayParam.java    |  35 ++
 .../org/apache/flink/ml/param/DoubleParam.java     |  35 ++
 .../org/apache/flink/ml/param/FloatArrayParam.java |  35 ++
 .../java/org/apache/flink/ml/param/FloatParam.java |  32 ++
 .../org/apache/flink/ml/param/IntArrayParam.java   |  35 ++
 .../java/org/apache/flink/ml/param/IntParam.java   |  35 ++
 .../org/apache/flink/ml/param/LongArrayParam.java  |  35 ++
 .../java/org/apache/flink/ml/param/LongParam.java  |  32 ++
 .../main/java/org/apache/flink/ml/param/Param.java |  98 ++++++
 .../org/apache/flink/ml/param/ParamValidator.java  |  40 +++
 .../org/apache/flink/ml/param/ParamValidators.java |  98 ++++++
 .../apache/flink/ml/param/StringArrayParam.java    |  35 ++
 .../org/apache/flink/ml/param/StringParam.java     |  35 ++
 .../java/org/apache/flink/ml/param/WithParams.java | 135 ++++++++
 .../java/org/apache/flink/ml/util/ParamUtils.java  |  89 +++++
 .../org/apache/flink/ml/util/ReadWriteUtils.java   | 279 +++++++++++++++
 .../apache/flink/ml/api/core/ExampleStages.java    | 244 ++++++++++++++
 .../org/apache/flink/ml/api/core/PipelineTest.java | 202 +++++------
 .../org/apache/flink/ml/api/core/StageTest.java    | 375 +++++++++++++++++++++
 pom.xml                                            |   2 -
 25 files changed, 1863 insertions(+), 141 deletions(-)

diff --git a/flink-ml-api/pom.xml b/flink-ml-api/pom.xml
index 81fdcc7..ddfc659 100644
--- a/flink-ml-api/pom.xml
+++ b/flink-ml-api/pom.xml
@@ -38,6 +38,21 @@ under the License.
       <version>${flink.version}</version>
       <scope>provided</scope>
     </dependency>
+
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-table-planner_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-test-utils_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+    </dependency>
+
     <dependency>
       <groupId>org.apache.flink</groupId>
       <artifactId>flink-shaded-jackson</artifactId>
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
index a5fed01..f1e5d0c 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
@@ -20,13 +20,17 @@ package org.apache.flink.ml.api.core;
 
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.table.api.Table;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 /**
  * A Pipeline acts as an Estimator. It consists of an ordered list of stages, each of which could be
@@ -36,10 +40,11 @@ import java.util.List;
 public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
     private static final long serialVersionUID = 6384850154817512318L;
     private final List<Stage<?>> stages;
-    private final Params params = new Params();
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
 
     public Pipeline(List<Stage<?>> stages) {
         this.stages = stages;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
     /**
@@ -97,17 +102,17 @@ public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
     }
 
     @Override
-    public void save(String path) throws IOException {
-        throw new UnsupportedOperationException();
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
     }
 
-    public static Pipeline load(String path) throws IOException {
-        throw new UnsupportedOperationException();
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.savePipeline(this, stages, path);
     }
 
-    @Override
-    public Params getParams() {
-        return params;
+    public static Pipeline load(String path) throws IOException {
+        return new Pipeline(ReadWriteUtils.loadPipeline(path, Pipeline.class.getName()));
     }
 
     /**
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
index 704fa8e..45bb757 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
@@ -20,12 +20,16 @@ package org.apache.flink.ml.api.core;
 
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.table.api.Table;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 /**
  * A PipelineModel acts as a Model. It consists of an ordered list of stages, each of which could be
@@ -35,10 +39,11 @@ import java.util.List;
 public final class PipelineModel implements Model<PipelineModel> {
     private static final long serialVersionUID = 6184950154217411318L;
     private final List<Stage<?>> stages;
-    private final Params params = new Params();
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
 
     public PipelineModel(List<Stage<?>> stages) {
         this.stages = stages;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
     /**
@@ -58,17 +63,17 @@ public final class PipelineModel implements Model<PipelineModel> {
     }
 
     @Override
-    public void save(String path) throws IOException {
-        throw new UnsupportedOperationException();
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
     }
 
-    public static PipelineModel load(String path) throws IOException {
-        throw new UnsupportedOperationException();
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.savePipeline(this, stages, path);
     }
 
-    @Override
-    public Params getParams() {
-        return params;
+    public static PipelineModel load(String path) throws IOException {
+        return new PipelineModel(ReadWriteUtils.loadPipeline(path, PipelineModel.class.getName()));
     }
 
     /**
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
index 551c5e5..168599b 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.api.core;
 
 import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.ml.api.misc.param.WithParams;
+import org.apache.flink.ml.param.WithParams;
 
 import java.io.IOException;
 import java.io.Serializable;
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java
new file mode 100644
index 0000000..dd96ebe
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding a {@link Boolean} value. */
+public class BooleanParam extends Param<Boolean> {
+
+    /**
+     * Creates a boolean parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public BooleanParam(String name, String description, Boolean defaultValue) {
+        super(name, Boolean.class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a boolean parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public BooleanParam(
+            String name,
+            String description,
+            Boolean defaultValue,
+            ParamValidator<Boolean> validator) {
+        super(name, Boolean.class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java
new file mode 100644
index 0000000..b86dd00
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding an array of {@link Double} values. */
+public class DoubleArrayParam extends Param<Double[]> {
+
+    /**
+     * Creates a double array parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public DoubleArrayParam(String name, String description, Double[] defaultValue) {
+        super(name, Double[].class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a double array parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public DoubleArrayParam(
+            String name,
+            String description,
+            Double[] defaultValue,
+            ParamValidator<Double[]> validator) {
+        super(name, Double[].class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java
new file mode 100644
index 0000000..f6d4911
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding a {@link Double} value. */
+public class DoubleParam extends Param<Double> {
+
+    /**
+     * Creates a double parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public DoubleParam(String name, String description, Double defaultValue) {
+        super(name, Double.class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a double parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public DoubleParam(
+            String name,
+            String description,
+            Double defaultValue,
+            ParamValidator<Double> validator) {
+        super(name, Double.class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java
new file mode 100644
index 0000000..4224557
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding an array of {@link Float} values. */
+public class FloatArrayParam extends Param<Float[]> {
+
+    /**
+     * Creates a float array parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public FloatArrayParam(String name, String description, Float[] defaultValue) {
+        super(name, Float[].class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a float array parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public FloatArrayParam(
+            String name,
+            String description,
+            Float[] defaultValue,
+            ParamValidator<Float[]> validator) {
+        super(name, Float[].class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java
new file mode 100644
index 0000000..0de890c
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding a {@link Float} value. */
+public class FloatParam extends Param<Float> {
+
+    /**
+     * Creates a float parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public FloatParam(String name, String description, Float defaultValue) {
+        super(name, Float.class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a float parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public FloatParam(
+            String name, String description, Float defaultValue, ParamValidator<Float> validator) {
+        super(name, Float.class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java
new file mode 100644
index 0000000..4f7c630
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding an array of {@link Integer} values. */
+public class IntArrayParam extends Param<Integer[]> {
+
+    /**
+     * Creates an integer array parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public IntArrayParam(String name, String description, Integer[] defaultValue) {
+        super(name, Integer[].class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates an integer array parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public IntArrayParam(
+            String name,
+            String description,
+            Integer[] defaultValue,
+            ParamValidator<Integer[]> validator) {
+        super(name, Integer[].class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java
new file mode 100644
index 0000000..4178e22
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding an {@link Integer} value. */
+public class IntParam extends Param<Integer> {
+
+    /**
+     * Creates an integer parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public IntParam(String name, String description, Integer defaultValue) {
+        super(name, Integer.class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates an integer parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public IntParam(
+            String name,
+            String description,
+            Integer defaultValue,
+            ParamValidator<Integer> validator) {
+        super(name, Integer.class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java
new file mode 100644
index 0000000..5e4fc47
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding an array of {@link Long} values. */
+public class LongArrayParam extends Param<Long[]> {
+
+    /**
+     * Creates a long array parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public LongArrayParam(String name, String description, Long[] defaultValue) {
+        super(name, Long[].class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a long array parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public LongArrayParam(
+            String name,
+            String description,
+            Long[] defaultValue,
+            ParamValidator<Long[]> validator) {
+        super(name, Long[].class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java
new file mode 100644
index 0000000..3fd7dd8
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter holding a {@link Long} value. */
+public class LongParam extends Param<Long> {
+
+    /**
+     * Creates a long parameter whose validator accepts every value.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     */
+    public LongParam(String name, String description, Long defaultValue) {
+        super(name, Long.class, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a long parameter whose values are checked by the given validator.
+     *
+     * @param name The name of the parameter.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws IllegalArgumentException if a non-null default value is rejected by the validator.
+     */
+    public LongParam(
+            String name, String description, Long defaultValue, ParamValidator<Long> validator) {
+        super(name, Long.class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
new file mode 100644
index 0000000..b7a1aef
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.util.ReadWriteUtils;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * Definition of a parameter, including name, class, description, default value and the validator.
+ *
+ * <p>Two parameters are equal if and only if they have the same name; the class, description,
+ * default value and validator are ignored by {@link #equals(Object)} and {@link #hashCode()}.
+ *
+ * @param <T> The class type of the parameter value.
+ */
+@PublicEvolving
+public class Param<T> implements Serializable {
+    private static final long serialVersionUID = 4396556083935765299L;
+
+    public final String name;
+    public final Class<T> clazz;
+    public final String description;
+    public final T defaultValue;
+    public final ParamValidator<T> validator;
+
+    /**
+     * Creates a parameter definition.
+     *
+     * @param name The name of the parameter; it alone determines equality and hash code.
+     * @param clazz The class of the parameter value, used by {@link #jsonDecode(String)}.
+     * @param description A human-readable description of the parameter.
+     * @param defaultValue The default value of the parameter; may be null.
+     * @param validator The validator used to check values of this parameter.
+     * @throws NullPointerException if {@code name} or {@code validator} is null.
+     * @throws IllegalArgumentException if a non-null {@code defaultValue} is rejected by the
+     *     validator.
+     */
+    public Param(
+            String name,
+            Class<T> clazz,
+            String description,
+            T defaultValue,
+            ParamValidator<T> validator) {
+        // equals/hashCode dereference the name and the default-value check below dereferences the
+        // validator, so fail fast with a named NPE instead of an anonymous one later.
+        this.name = Objects.requireNonNull(name, "name");
+        this.clazz = clazz;
+        this.description = description;
+        this.defaultValue = defaultValue;
+        this.validator = Objects.requireNonNull(validator, "validator");
+
+        // A null default value is accepted without consulting the validator.
+        if (defaultValue != null && !validator.validate(defaultValue)) {
+            throw new IllegalArgumentException(
+                    "Parameter "
+                            + name
+                            + " is given an invalid value "
+                            + valueToString(defaultValue));
+        }
+    }
+
+    /**
+     * Encodes the given object into a json-formatted string.
+     *
+     * @param value An object of class type T.
+     * @return A json-formatted string.
+     * @throws IOException if {@code ReadWriteUtils.OBJECT_MAPPER} fails to serialize the value.
+     */
+    public String jsonEncode(T value) throws IOException {
+        return ReadWriteUtils.OBJECT_MAPPER.writeValueAsString(value);
+    }
+
+    /**
+     * Decodes the given string into an object of class type T.
+     *
+     * @param json A json-formatted string.
+     * @return An object of class type T.
+     * @throws IOException if {@code json} cannot be decoded into an instance of {@code clazz}.
+     */
+    public T jsonDecode(String json) throws IOException {
+        return ReadWriteUtils.OBJECT_MAPPER.readValue(json, clazz);
+    }
+
+    /** Returns true if and only if {@code obj} is a {@link Param} with the same name. */
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (!(obj instanceof Param)) {
+            return false;
+        }
+        return ((Param<?>) obj).name.equals(name);
+    }
+
+    /** Hashes only the name, consistent with {@link #equals(Object)}. */
+    @Override
+    public int hashCode() {
+        return name.hashCode();
+    }
+
+    /** Returns the name of the parameter. */
+    @Override
+    public String toString() {
+        return name;
+    }
+
+    /** Formats a value for error messages, printing the elements of object arrays. */
+    private static String valueToString(Object value) {
+        if (value instanceof Object[]) {
+            return Arrays.deepToString((Object[]) value);
+        }
+        return String.valueOf(value);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java
new file mode 100644
index 0000000..afdcd9a
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import java.io.Serializable;
+
+/**
+ * An interface to validate that a parameter value is valid.
+ *
+ * <p>It declares a single abstract method, so validators can be written as lambdas; because it
+ * extends {@link Serializable}, such lambdas are serializable as well.
+ *
+ * @param <T> The class type of the parameter value.
+ */
+@PublicEvolving
+@FunctionalInterface
+public interface ParamValidator<T> extends Serializable {
+
+    /**
+     * Validates whether the parameter value is valid.
+     *
+     * @param value The parameter value to check. The validators in {@link ParamValidators}
+     *     check for null explicitly, so implementations should be prepared to receive it.
+     * @return true if the parameter value is valid, false otherwise.
+     */
+    boolean validate(T value);
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java
new file mode 100644
index 0000000..925ccb2
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+/**
+ * Factory methods for common {@link ParamValidator}s: numerical comparisons and ranges,
+ * membership in a set of allowed values, and non-null checks.
+ *
+ * <p>Every validator returned here except {@link #alwaysTrue()} rejects null values. The numerical
+ * validators cast the value to {@link Number}, so they throw a {@link ClassCastException} when
+ * applied to a non-null, non-numeric value.
+ */
+public class ParamValidators {
+
+    /** Returns a validator that accepts every value, including null. */
+    public static <T> ParamValidator<T> alwaysTrue() {
+        return (value) -> true;
+    }
+
+    /** Returns a validator accepting non-null values greater than {@code lowerBound}. */
+    public static <T> ParamValidator<T> gt(double lowerBound) {
+        return (value) -> value != null && ((Number) value).doubleValue() > lowerBound;
+    }
+
+    /** Returns a validator accepting non-null values of at least {@code lowerBound}. */
+    public static <T> ParamValidator<T> gtEq(double lowerBound) {
+        return (value) -> value != null && ((Number) value).doubleValue() >= lowerBound;
+    }
+
+    /** Returns a validator accepting non-null values less than {@code upperBound}. */
+    public static <T> ParamValidator<T> lt(double upperBound) {
+        return (value) -> value != null && ((Number) value).doubleValue() < upperBound;
+    }
+
+    /** Returns a validator accepting non-null values of at most {@code upperBound}. */
+    public static <T> ParamValidator<T> ltEq(double upperBound) {
+        return (value) -> value != null && ((Number) value).doubleValue() <= upperBound;
+    }
+
+    /**
+     * Returns a validator accepting non-null values in the range from {@code lowerBound} to
+     * {@code upperBound}.
+     *
+     * @param lowerBound The lower end of the range.
+     * @param upperBound The upper end of the range.
+     * @param lowerInclusive if true, range includes value = lowerBound
+     * @param upperInclusive if true, range includes value = upperBound
+     */
+    public static <T> ParamValidator<T> inRange(
+            double lowerBound, double upperBound, boolean lowerInclusive, boolean upperInclusive) {
+        return (obj) -> {
+            if (obj == null) {
+                return false;
+            }
+            double value = ((Number) obj).doubleValue();
+            return (value >= lowerBound)
+                    && (value <= upperBound)
+                    && (lowerInclusive || value != lowerBound)
+                    && (upperInclusive || value != upperBound);
+        };
+    }
+
+    /** Returns a validator accepting non-null values in {@code [lowerBound, upperBound]}. */
+    public static <T> ParamValidator<T> inRange(double lowerBound, double upperBound) {
+        return inRange(lowerBound, upperBound, true, true);
+    }
+
+    /**
+     * Returns a validator accepting non-null values contained in {@code allowed}.
+     *
+     * <p>Membership is checked with {@link ArrayUtils#contains(Object[], Object)}, which relies on
+     * {@link Object#equals(Object)}; for array-typed values this means only the identical array
+     * instance is accepted.
+     */
+    @SafeVarargs
+    public static <T> ParamValidator<T> inArray(T... allowed) {
+        return (value) -> value != null && ArrayUtils.contains(allowed, value);
+    }
+
+    /** Returns a validator accepting any non-null value. */
+    public static <T> ParamValidator<T> notNull() {
+        return (value) -> value != null;
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java
new file mode 100644
index 0000000..5062463
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter whose value is an array of strings ({@code String[]}). */
+public class StringArrayParam extends Param<String[]> {
+
+    /**
+     * Creates a string array parameter that uses {@link ParamValidators#alwaysTrue()} as its
+     * validator.
+     *
+     * @param name the parameter name.
+     * @param description the parameter description.
+     * @param defaultValue the default value of the parameter.
+     */
+    public StringArrayParam(String name, String description, String[] defaultValue) {
+        this(name, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a string array parameter with a custom validator.
+     *
+     * @param name the parameter name.
+     * @param description the parameter description.
+     * @param defaultValue the default value of the parameter.
+     * @param validator the validator applied to values of this parameter.
+     */
+    public StringArrayParam(
+            String name,
+            String description,
+            String[] defaultValue,
+            ParamValidator<String[]> validator) {
+        super(name, String[].class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java
new file mode 100644
index 0000000..1736354
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A parameter whose value is a single {@code String}. */
+public class StringParam extends Param<String> {
+
+    /**
+     * Creates a string parameter that uses {@link ParamValidators#alwaysTrue()} as its validator.
+     *
+     * @param name the parameter name.
+     * @param description the parameter description.
+     * @param defaultValue the default value of the parameter.
+     */
+    public StringParam(String name, String description, String defaultValue) {
+        this(name, description, defaultValue, ParamValidators.alwaysTrue());
+    }
+
+    /**
+     * Creates a string parameter with a custom validator.
+     *
+     * @param name the parameter name.
+     * @param description the parameter description.
+     * @param defaultValue the default value of the parameter.
+     * @param validator the validator applied to values of this parameter.
+     */
+    public StringParam(
+            String name,
+            String description,
+            String defaultValue,
+            ParamValidator<String> validator) {
+        super(name, String.class, description, defaultValue, validator);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java
new file mode 100644
index 0000000..f631c8e
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.util.ParamUtils;
+
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Interface for classes that take parameters. It provides APIs to set and get parameters.
+ *
+ * @param <T> The class type of WithParams implementation itself.
+ */
+@PublicEvolving
+public interface WithParams<T> {
+
+    /**
+     * Gets the parameter by its name.
+     *
+     * <p>Only the parameters that are keys of {@link #getParamMap()} are searched.
+     *
+     * @param name The parameter name.
+     * @param <V> The class type of the parameter value.
+     * @return The parameter with the given name, or {@code null} if {@link #getParamMap()} contains
+     *     no parameter with that name.
+     */
+    @SuppressWarnings("unchecked")
+    default <V> Param<V> getParam(String name) {
+        Optional<Param<?>> result =
+                getParamMap().keySet().stream().filter(param -> param.name.equals(name)).findAny();
+        // The caller chooses V; the cast is unchecked because the map key is Param<?>.
+        return (Param<V>) result.orElse(null);
+    }
+
+    /**
+     * Sets the value of the parameter.
+     *
+     * @param param The parameter.
+     * @param value The parameter value.
+     * @param <V> The class type of the parameter value.
+     * @return The WithParams instance itself.
+     * @throws ClassCastException if {@code value} is non-null and not an instance of the class
+     *     declared by {@code param}.
+     * @throws IllegalArgumentException if the parameter's validator rejects {@code value}.
+     */
+    @SuppressWarnings("unchecked")
+    default <V> T set(Param<V> param, V value) {
+        // Guards against values smuggled in through raw types / unchecked casts.
+        if (value != null && !param.clazz.isAssignableFrom(value.getClass())) {
+            throw new ClassCastException(
+                    "Parameter "
+                            + param.name
+                            + " is given a value with incompatible class "
+                            + value.getClass().getName());
+        }
+
+        if (!param.validator.validate(value)) {
+            if (value == null) {
+                throw new IllegalArgumentException(
+                        "Parameter " + param.name + "'s value should not be null");
+            } else {
+                throw new IllegalArgumentException(
+                        "Parameter "
+                                + param.name
+                                + " is given an invalid value "
+                                + value.toString());
+            }
+        }
+        getParamMap().put(param, value);
+        // T is the implementing class itself, so returning this enables fluent setter chains.
+        return (T) this;
+    }
+
+    /**
+     * Gets the value of the parameter.
+     *
+     * @param param The parameter.
+     * @param <V> The class type of the parameter value.
+     * @return The parameter value, which may be {@code null} if the parameter's validator accepts
+     *     {@code null}.
+     * @throws IllegalArgumentException if the stored value is {@code null} (or the parameter is
+     *     absent from {@link #getParamMap()}) and the parameter's validator rejects {@code null}.
+     */
+    @SuppressWarnings("unchecked")
+    default <V> V get(Param<V> param) {
+        Map<Param<?>, Object> paramMap = getParamMap();
+        V value = (V) paramMap.get(param);
+
+        // Non-null values were already validated by set(...); only a null needs checking here.
+        if (value == null && !param.validator.validate(value)) {
+            throw new IllegalArgumentException(
+                    "Parameter " + param.name + "'s value should not be null");
+        }
+
+        return value;
+    }
+
+    /**
+     * Returns a map which should contain value for every parameter that meets one of the following
+     * conditions.
+     *
+     * <p>1) set(...) has been called to set value for this parameter.
+     *
+     * <p>2) The parameter is a public final field of this WithParams instance. This includes fields
+     * inherited from its interfaces and super-classes.
+     *
+     * <p>The subclass which implements this interface could meet this requirement by returning a
+     * member field of the given map type, after having initialized this member field using the
+     * {@link ParamUtils#initializeMapWithDefaultValues(Map, WithParams)} method.
+     *
+     * @return A map which maps parameter definition to parameter value.
+     */
+    Map<Param<?>, Object> getParamMap();
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java
new file mode 100644
index 0000000..cdbe63d
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.WithParams;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Modifier;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/** Utility methods for reading and writing stages. */
+public class ParamUtils {
+    /**
+     * Updates the paramMap with default values of all public final Param-typed fields of the given
+     * instance. A parameter's value will not be updated if this parameter is already found in the
+     * map.
+     *
+     * <p>Note: This method should be called after all public final Param-typed fields of the given
+     * instance have been defined. A good choice is to call this method in the constructor of the
+     * given instance.
+     */
+    public static void initializeMapWithDefaultValues(
+            Map<Param<?>, Object> paramMap, WithParams<?> instance) {
+        List<Param<?>> defaultParams = getPublicFinalParamFields(instance);
+        for (Param<?> param : defaultParams) {
+            if (!paramMap.containsKey(param)) {
+                paramMap.put(param, param.defaultValue);
+            }
+        }
+    }
+
+    /**
+     * Finds all public final fields of the Param class type of the given object, including those
+     * fields inherited from its interfaces and super-classes, and returns those Param instances as
+     * a list.
+     *
+     * @param object the object whose public final Param-typed fields will be returned.
+     * @return a list of Param instances.
+     */
+    public static List<Param<?>> getPublicFinalParamFields(Object object) {
+        return getPublicFinalParamFields(object, object.getClass());
+    }
+
+    // A helper method that finds all public final fields of the Param class type of the given
+    // object and returns those Param instances as a list. The clazz specifies the object class.
+    private static List<Param<?>> getPublicFinalParamFields(Object object, Class<?> clazz) {
+        List<Param<?>> result = new ArrayList<>();
+        for (Field field : clazz.getDeclaredFields()) {
+            field.setAccessible(true);
+            if (Param.class.isAssignableFrom(field.getType())
+                    && Modifier.isPublic(field.getModifiers())
+                    && Modifier.isFinal(field.getModifiers())) {
+                try {
+                    result.add((Param<?>) field.get(object));
+                } catch (IllegalAccessException e) {
+                    throw new RuntimeException(
+                            "Failed to extract param from field " + field.getName(), e);
+                }
+            }
+        }
+
+        if (clazz.getSuperclass() != null) {
+            result.addAll(getPublicFinalParamFields(object, clazz.getSuperclass()));
+        }
+        for (Class<?> cls : clazz.getInterfaces()) {
+            result.addAll(getPublicFinalParamFields(object, cls));
+        }
+        return result;
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
new file mode 100644
index 0000000..283c1e5
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.ml.api.core.Stage;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.util.InstantiationUtil;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** Utility methods for reading and writing stages. */
+public class ReadWriteUtils {
+    /** Jackson mapper used to serialize and deserialize the stage metadata as JSON. */
+    public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+    // A helper method that encodes the given parameter value to a json string. We can not call
+    // param.jsonEncode(value) directly because Param::jsonEncode(...) needs the actual type of the
+    // value.
+    @SuppressWarnings("unchecked")
+    private static <T> String jsonEncodeHelper(Param<T> param, Object value) throws IOException {
+        return param.jsonEncode((T) value);
+    }
+
+    // Converts Map<Param<?>, Object> to Map<String, String> which maps the parameter name to the
+    // string-encoded parameter value.
+    private static Map<String, String> jsonEncode(Map<Param<?>, Object> paramMap)
+            throws IOException {
+        Map<String, String> result = new HashMap<>(paramMap.size());
+        for (Map.Entry<Param<?>, Object> entry : paramMap.entrySet()) {
+            String json = jsonEncodeHelper(entry.getKey(), entry.getValue());
+            result.put(entry.getKey().name, json);
+        }
+        return result;
+    }
+
+    /**
+     * Saves the metadata of the given stage and the extra metadata to a file named `metadata` under
+     * the given path. The metadata of a stage includes the stage class name, parameter values etc.
+     *
+     * <p>The file is always encoded as UTF-8, independent of the platform default charset.
+     *
+     * <p>Required: the metadata file under the given path should not exist.
+     *
+     * @param stage The stage instance.
+     * @param path The parent directory to save the stage metadata.
+     * @param extraMetadata The extra metadata to be saved.
+     * @throws IOException if the metadata file already exists or cannot be written.
+     */
+    public static void saveMetadata(Stage<?> stage, String path, Map<String, ?> extraMetadata)
+            throws IOException {
+        // Creates parent directories if not already created.
+        new File(path).mkdirs();
+
+        Map<String, Object> metadata = new HashMap<>(extraMetadata);
+        metadata.put("className", stage.getClass().getName());
+        metadata.put("timestamp", System.currentTimeMillis());
+        metadata.put("paramMap", jsonEncode(stage.getParamMap()));
+        // TODO: add version in the metadata.
+        String metadataStr = OBJECT_MAPPER.writeValueAsString(metadata);
+
+        File metadataFile = new File(path, "metadata");
+        if (!metadataFile.createNewFile()) {
+            throw new IOException("File " + metadataFile.toString() + " already exists.");
+        }
+        // Use an explicit charset: FileWriter would use the platform default charset, so param
+        // values with non-ASCII text might not decode correctly on another machine.
+        try (BufferedWriter writer =
+                Files.newBufferedWriter(metadataFile.toPath(), StandardCharsets.UTF_8)) {
+            writer.write(metadataStr);
+        }
+    }
+
+    /**
+     * Saves the metadata of the given stage to a file named `metadata` under the given path. The
+     * metadata of a stage includes the stage class name, parameter values etc.
+     *
+     * <p>Required: the metadata file under the given path should not exist.
+     *
+     * @param stage The stage instance.
+     * @param path The parent directory to save the stage metadata.
+     */
+    public static void saveMetadata(Stage<?> stage, String path) throws IOException {
+        saveMetadata(stage, path, new HashMap<>());
+    }
+
+    /**
+     * Loads the metadata from the metadata file under the given path. The file is decoded as UTF-8;
+     * lines starting with `#` are ignored.
+     *
+     * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not
+     * match the className of the previously saved stage.
+     *
+     * @param path The parent directory of the metadata file to read from.
+     * @param expectedClassName The expected class name of the stage.
+     * @return A map from metadata name to metadata value.
+     */
+    public static Map<String, ?> loadMetadata(String path, String expectedClassName)
+            throws IOException {
+        Path metadataPath = Paths.get(path, "metadata");
+        StringBuilder buffer = new StringBuilder();
+        // Must use the same charset as saveMetadata(...).
+        try (BufferedReader br = Files.newBufferedReader(metadataPath, StandardCharsets.UTF_8)) {
+            String line;
+            while ((line = br.readLine()) != null) {
+                if (!line.startsWith("#")) {
+                    buffer.append(line);
+                }
+            }
+        }
+
+        @SuppressWarnings("unchecked")
+        Map<String, ?> result = OBJECT_MAPPER.readValue(buffer.toString(), Map.class);
+
+        String className = (String) result.get("className");
+        if (!expectedClassName.isEmpty() && !expectedClassName.equals(className)) {
+            throw new RuntimeException(
+                    "Class name "
+                            + className
+                            + " does not match the expected class name "
+                            + expectedClassName
+                            + ".");
+        }
+
+        return result;
+    }
+
+    // Returns a string with value {parentPath}/stages/{stageIdx}, where the stageIdx is prefixed
+    // with zero or more `0` to have the same length as numStages. The resulting string can be
+    // used as the directory to save a stage of the Pipeline or PipelineModel.
+    private static String getPathForPipelineStage(int stageIdx, int numStages, String parentPath) {
+        String format = String.format("%%0%dd", String.valueOf(numStages).length());
+        String fileName = String.format(format, stageIdx);
+        return Paths.get(parentPath, "stages", fileName).toString();
+    }
+
+    /**
+     * Saves a Pipeline or PipelineModel with the given list of stages to the given path.
+     *
+     * @param pipeline A Pipeline or PipelineModel instance.
+     * @param stages A list of stages of the given pipeline.
+     * @param path The parent directory to save the pipeline metadata and its stages.
+     */
+    public static void savePipeline(Stage<?> pipeline, List<Stage<?>> stages, String path)
+            throws IOException {
+        // Creates parent directories if not already created.
+        new File(path).mkdirs();
+
+        Map<String, Object> extraMetadata = new HashMap<>();
+        extraMetadata.put("numStages", stages.size());
+        saveMetadata(pipeline, path, extraMetadata);
+
+        int numStages = stages.size();
+        for (int i = 0; i < numStages; i++) {
+            String stagePath = getPathForPipelineStage(i, numStages, path);
+            stages.get(i).save(stagePath);
+        }
+    }
+
+    /**
+     * Loads the stages of a Pipeline or PipelineModel from the given path.
+     *
+     * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not
+     * match the className of the previously saved Pipeline or PipelineModel.
+     *
+     * @param path The parent directory to load the pipeline metadata and its stages.
+     * @param expectedClassName The expected class name of the pipeline.
+     * @return A list of stages.
+     */
+    public static List<Stage<?>> loadPipeline(String path, String expectedClassName)
+            throws IOException {
+        Map<String, ?> metadata = loadMetadata(path, expectedClassName);
+        // Read through Number so that the value is accepted whatever numeric type JSON yields.
+        int numStages = ((Number) metadata.get("numStages")).intValue();
+        List<Stage<?>> stages = new ArrayList<>(numStages);
+
+        for (int i = 0; i < numStages; i++) {
+            String stagePath = getPathForPipelineStage(i, numStages, path);
+            stages.add(loadStage(stagePath));
+        }
+        return stages;
+    }
+
+    // A helper method that sets stage's parameter value. We can not call stage.set(param, value)
+    // directly because stage::set(...) needs the actual type of the value.
+    @SuppressWarnings("unchecked")
+    public static <T> void setStageParam(Stage<?> stage, Param<T> param, Object value) {
+        stage.set(param, (T) value);
+    }
+
+    /**
+     * Loads the stage with the saved parameters from the given path. This method reads the metadata
+     * file under the given path, instantiates the stage using its no-argument constructor, and
+     * loads the stage with the paramMap from the metadata file.
+     *
+     * <p>Note: This method does not attempt to read model data from the given path. Caller needs to
+     * read model data from the given path if the stage has model data.
+     *
+     * <p>Required: the class with type T must have a no-argument constructor.
+     *
+     * @param path The parent directory of the stage metadata file.
+     * @param <T> The class type of the Stage subclass.
+     * @return An instance of class type T.
+     * @throws RuntimeException if the class cannot be found, or if a saved parameter name does not
+     *     match any public final Param field of the instantiated stage.
+     */
+    @SuppressWarnings("unchecked")
+    public static <T extends Stage<T>> T loadStageParam(String path) throws IOException {
+        Map<String, ?> metadata = loadMetadata(path, "");
+        String className = (String) metadata.get("className");
+        Map<String, String> paramMap = (Map<String, String>) metadata.get("paramMap");
+
+        try {
+            Class<T> clazz = (Class<T>) Class.forName(className);
+            T instance = InstantiationUtil.instantiate(clazz);
+
+            Map<String, Param<?>> nameToParam = new HashMap<>();
+            for (Param<?> param : ParamUtils.getPublicFinalParamFields(instance)) {
+                nameToParam.put(param.name, param);
+            }
+
+            for (Map.Entry<String, String> entry : paramMap.entrySet()) {
+                Param<?> param = nameToParam.get(entry.getKey());
+                // Report unknown names explicitly instead of failing later with an NPE.
+                if (param == null) {
+                    throw new RuntimeException(
+                            "Failed to load stage: class "
+                                    + className
+                                    + " does not have a public final parameter named "
+                                    + entry.getKey()
+                                    + ".");
+                }
+                setStageParam(instance, param, param.jsonDecode(entry.getValue()));
+            }
+            return instance;
+        } catch (ClassNotFoundException e) {
+            throw new RuntimeException("Failed to load stage.", e);
+        }
+    }
+
+    /**
+     * Loads the stage from the given path by invoking the static load() method of the stage. The
+     * stage class name is read from the metadata file under the given path. The load() method is
+     * expected to construct the stage instance with the saved parameters, model data and other
+     * metadata if exists.
+     *
+     * <p>Required: the stage class must have a static load() method.
+     *
+     * @param path The parent directory of the stage metadata file.
+     * @return An instance of Stage.
+     */
+    public static Stage<?> loadStage(String path) throws IOException {
+        Map<String, ?> metadata = loadMetadata(path, "");
+        String className = (String) metadata.get("className");
+
+        try {
+            Class<?> clazz = Class.forName(className);
+            Method method = clazz.getMethod("load", String.class);
+            method.setAccessible(true);
+            return (Stage<?>) method.invoke(null, path);
+        } catch (NoSuchMethodException e) {
+            String methodName = String.format("%s::load(String)", className);
+            throw new RuntimeException(
+                    "Failed to load stage because the static method "
+                            + methodName
+                            + " is not implemented.",
+                    e);
+        } catch (ClassNotFoundException | IllegalAccessException | InvocationTargetException e) {
+            throw new RuntimeException("Failed to load stage.", e);
+        }
+    }
+}
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
new file mode 100644
index 0000000..2e4b4c2
--- /dev/null
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.api.common.state.BroadcastState;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.util.HashMap;
+import java.util.Map;
+
+/** Defines a few Stage subclasses to be used in unit tests. */
+public class ExampleStages {
+    /**
+     * A Model subclass that increments every value in the input stream by `delta` and outputs the
+     * resulting values.
+     */
+    public static class SumModel implements Model<SumModel> {
+        private final Map<Param<?>, Object> paramMap = new HashMap<>();
+        private DataStream<Integer> modelData;
+
+        // This empty constructor is necessary in order for SumModel to be loaded by
+        // ReadWriteUtils.loadStageParam(...), which instantiates the class reflectively.
+        public SumModel() {
+            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        }
+
+        @Override
+        public Map<Param<?>, Object> getParamMap() {
+            return paramMap;
+        }
+
+        @Override
+        public Table[] transform(Table... inputs) {
+            Assert.assertEquals(1, inputs.length);
+            StreamTableEnvironment tEnv =
+                    (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+            DataStream<Integer> input = tEnv.toDataStream(inputs[0], Integer.class);
+            // The model data is broadcast so that every parallel instance sees the delta.
+            DataStream<Integer> output =
+                    input.connect(modelData.broadcast())
+                            .transform(
+                                    "ApplyDeltaOperator",
+                                    BasicTypeInfo.INT_TYPE_INFO,
+                                    new ApplyDeltaOperator());
+
+            return new Table[] {tEnv.fromDataStream(output)};
+        }
+
+        @Override
+        public void setModelData(Table... inputs) {
+            StreamTableEnvironment tEnv =
+                    (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+            modelData = tEnv.toDataStream(inputs[0], Integer.class);
+        }
+
+        @Override
+        public void save(String path) throws IOException {
+            ReadWriteUtils.saveMetadata(this, path);
+
+            File dataDir = new File(path, "data");
+            if (!dataDir.mkdir()) {
+                throw new IOException("Directory " + dataDir.toString() + " already exists.");
+            }
+
+            File dataFile = new File(dataDir, "delta");
+            if (!dataFile.createNewFile()) {
+                throw new IOException("File " + dataFile.toString() + " already exists.");
+            }
+
+            int delta;
+            // executeAndCollect() returns a CloseableIterator backed by a running collect job; it
+            // must be closed, so it is managed with try-with-resources.
+            try (org.apache.flink.util.CloseableIterator<Integer> iterator =
+                    modelData.executeAndCollect()) {
+                delta = (Integer) IteratorUtils.toList(iterator).get(0);
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+
+            try (DataOutputStream outputStream =
+                    new DataOutputStream(new FileOutputStream(dataFile))) {
+                outputStream.writeInt(delta);
+            }
+        }
+
+        public static SumModel load(String path) throws IOException {
+            SumModel model = ReadWriteUtils.loadStageParam(path);
+            File dataFile = Paths.get(path, "data", "delta").toFile();
+
+            try (DataInputStream inputStream = new DataInputStream(new FileInputStream(dataFile))) {
+                StreamExecutionEnvironment env =
+                        StreamExecutionEnvironment.getExecutionEnvironment();
+                StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+                Table modelData = tEnv.fromDataStream(env.fromElements(inputStream.readInt()));
+                model.setModelData(modelData);
+                return model;
+            }
+        }
+    }
+
+    // Adds delta from the 2nd input to every element in the 1st input and returns the added values.
+    private static class ApplyDeltaOperator extends AbstractStreamOperator<Integer>
+            implements TwoInputStreamOperator<Integer, Integer, Integer> {
+        private ListState<Integer> unProcessedValues;
+        private BroadcastState<String, Integer> broadcastState = null;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            unProcessedValues =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<Integer>(
+                                            "unProcessedValues", Integer.class));
+            broadcastState =
+                    context.getOperatorStateStore()
+                            .getBroadcastState(
+                                    new MapStateDescriptor<String, Integer>(
+                                            "broadcastState", String.class, Integer.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Integer> record) throws Exception {
+            if (broadcastState.get("delta") == null) {
+                unProcessedValues.add(record.getValue());
+            } else {
+                output.collect(new StreamRecord<>(record.getValue() + broadcastState.get("delta")));
+            }
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Integer> record) throws Exception {
+            if (broadcastState.get("delta") != null) {
+                throw new IllegalStateException("Model data should have exactly one value");
+            }
+            broadcastState.put("delta", record.getValue());
+
+            for (Integer value : unProcessedValues.get()) {
+                output.collect(new StreamRecord<>(value + record.getValue()));
+            }
+            unProcessedValues.clear();
+        }
+    }
+
+    /**
+     * An Estimator implementation which computes the sum of its input values and returns a
+     * SumModel whose model data is that sum, i.e. delta = sum(inputs).
+     */
+    public static class SumEstimator implements Estimator<SumEstimator, SumModel> {
+        private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+        public SumEstimator() {
+            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        }
+
+        @Override
+        public Map<Param<?>, Object> getParamMap() {
+            return paramMap;
+        }
+
+        /**
+         * Sums the values of the single input table and wraps the result in a new SumModel.
+         *
+         * @param inputs exactly one table, read as a stream of Integer values
+         * @return a SumModel whose model data holds the sum of the input values
+         */
+        @Override
+        public SumModel fit(Table... inputs) {
+            Assert.assertEquals(1, inputs.length);
+            StreamTableEnvironment tEnv =
+                    (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+            DataStream<Integer> input = tEnv.toDataStream(inputs[0], Integer.class);
+            // Parallelism 1: a single SumOperator instance must see every element so that exactly
+            // one global sum is emitted.
+            DataStream<Integer> modelData =
+                    input.transform("SumOperator", BasicTypeInfo.INT_TYPE_INFO, new SumOperator())
+                            .setParallelism(1);
+            try {
+                SumModel model = new SumModel();
+                model.setModelData(tEnv.fromDataStream(modelData));
+
+                return model;
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        /** Saves only the stage metadata; the estimator itself has no model data. */
+        @Override
+        public void save(String path) throws IOException {
+            ReadWriteUtils.saveMetadata(this, path);
+        }
+
+        /** Restores a SumEstimator from the metadata written by {@link #save(String)}. */
+        public static SumEstimator load(String path) throws IOException {
+            return ReadWriteUtils.loadStageParam(path);
+        }
+    }
+
+    /** Adds up every element of its bounded input and emits the total once the input has ended. */
+    private static class SumOperator extends AbstractStreamOperator<Integer>
+            implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {
+        // Running total, kept in a plain field rather than in Flink managed state.
+        int sum;
+
+        @Override
+        public void processElement(StreamRecord<Integer> input) throws Exception {
+            sum = sum + input.getValue();
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            output.collect(new StreamRecord<>(sum));
+        }
+    }
+}
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
index 6d46430..74f9c65 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
@@ -18,141 +18,103 @@
 
 package org.apache.flink.ml.api.core;
 
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.api.core.ExampleStages.SumEstimator;
+import org.apache.flink.ml.api.core.ExampleStages.SumModel;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
 
-import org.junit.Rule;
+import org.apache.commons.collections.IteratorUtils;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
-import java.io.IOException;
-import java.util.ArrayList;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Comparator;
 import java.util.List;
 
-/** Tests the behavior of {@link Pipeline}. */
-public class PipelineTest {
-    @Rule public ExpectedException thrown = ExpectedException.none();
+/** Tests the behavior of {@link Pipeline} and {@link PipelineModel}. */
+public class PipelineTest extends AbstractTestBase {

-    @Test
-    public void testPipelineBehavior() {
-        List<Stage<?>> stages = new ArrayList<>();
-        stages.add(new MockTransformer("a"));
-        stages.add(new MockEstimator("b"));
-        stages.add(new MockEstimator("c"));
-        stages.add(new MockTransformer("d"));
-
-        Pipeline pipeline = new Pipeline(stages);
-        assert describePipeline(pipeline.getStages()).equals("a_b_c_d");
-
-        PipelineModel pipelineModel = pipeline.fit(null, null);
-        assert describePipeline(pipelineModel.getStages()).equals("a_mb_mc_d");
-    }
-
-    private static String describePipeline(List<Stage<?>> stages) {
-        StringBuilder res = new StringBuilder();
-        for (Stage<?> s : stages) {
-            if (res.length() != 0) {
-                res.append("_");
-            }
-            res.append(((SelfDescribe) s).describe());
-        }
-        return res.toString();
-    }
-
-    /** Interface to describe a class with a string, only for pipeline test. */
-    private interface SelfDescribe {
-        ParamInfo<String> DESCRIPTION =
-                ParamInfoFactory.createParamInfo("description", String.class).build();
-
-        String describe();
-    }
-
-    /** Mock estimator for pipeline test. */
-    public static class MockEstimator implements Estimator<MockEstimator, MockModel>, SelfDescribe {
-        private final Params params = new Params();
+    // Executes the given stage and verifies that it produces the expected output.
+    //
+    // AlgoOperators are applied to the input directly. Any other stage is cast to an Estimator,
+    // fitted on the input, and the fitted Model is then applied to the same input.
+    private static void executeAndCheckOutput(
+            Stage<?> stage, List<Integer> input, List<Integer> expectedOutput) throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        // With parallelism 4 the output order is not deterministic; the comparison at the end
+        // therefore passes a Comparator (presumably used to sort both lists before comparing).
+        env.setParallelism(4);

-        public MockEstimator() {}
+        Table inputTable = tEnv.fromDataStream(env.fromCollection(input));

-        MockEstimator(String description) {
-            set(DESCRIPTION, description);
-        }
-
-        @Override
-        public MockModel fit(Table... inputs) {
-            return new MockModel("m" + describe());
-        }
-
-        @Override
-        public Params getParams() {
-            return params;
-        }
+        Table outputTable;

-        @Override
-        public String describe() {
-            return get(DESCRIPTION);
+        // Only the first output table of the stage is checked.
+        if (stage instanceof AlgoOperator) {
+            outputTable = ((AlgoOperator<?>) stage).transform(inputTable)[0];
+        } else {
+            Estimator<?, ?> estimator = (Estimator<?, ?>) stage;
+            Model<?> model = estimator.fit(inputTable);
+            outputTable = model.transform(inputTable)[0];
         }

-        @Override
-        public void save(String path) throws IOException {}
+        List<Integer> output =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(outputTable, Integer.class).executeAndCollect());
+        compareResultCollections(expectedOutput, output, Comparator.naturalOrder());
     }
 
-    /** Mock transformer for pipeline test. */
-    public static class MockTransformer implements Transformer<MockTransformer>, SelfDescribe {
-        private final Params params = new Params();
-
-        public MockTransformer() {}
-
-        MockTransformer(String description) {
-            set(DESCRIPTION, description);
-        }
-
-        @Override
-        public Table[] transform(Table... inputs) {
-            return inputs;
-        }
-
-        @Override
-        public Params getParams() {
-            return params;
-        }
-
-        @Override
-        public String describe() {
-            return get(DESCRIPTION);
-        }
-
-        @Override
-        public void save(String path) throws IOException {}
+    @Test
+    public void testPipelineModel() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        // Builds a PipelineModel that increments input value by 60. This PipelineModel consists of
+        // three stages where each stage increments input value by 10, 20, and 30 respectively.
+        SumModel modelA = new SumModel();
+        modelA.setModelData(tEnv.fromValues(10));
+        SumModel modelB = new SumModel();
+        modelB.setModelData(tEnv.fromValues(20));
+        SumModel modelC = new SumModel();
+        modelC.setModelData(tEnv.fromValues(30));
+
+        List<Stage<?>> stages = Arrays.asList(modelA, modelB, modelC);
+        Model<?> model = new PipelineModel(stages);
+
+        // Executes the original PipelineModel and verifies that it produces the expected output.
+        executeAndCheckOutput(model, Arrays.asList(1, 2, 3), Arrays.asList(61, 62, 63));
+
+        // Saves and loads the PipelineModel.
+        Path tempDir = Files.createTempDirectory("PipelineTest");
+        String path = Paths.get(tempDir.toString(), "testPipelineModelSaveLoad").toString();
+        model.save(path);
+        Model<?> loadedModel = PipelineModel.load(path);
+
+        // Executes the loaded PipelineModel and verifies that it produces the expected output.
+        executeAndCheckOutput(loadedModel, Arrays.asList(1, 2, 3), Arrays.asList(61, 62, 63));
     }
 
-    /** Mock model for pipeline test. */
-    public static class MockModel implements Model<MockModel>, SelfDescribe {
-        private final Params params = new Params();
-
-        public MockModel() {}
-
-        MockModel(String description) {
-            set(DESCRIPTION, description);
-        }
-
-        @Override
-        public Table[] transform(Table... inputs) {
-            return inputs;
-        }
-
-        @Override
-        public Params getParams() {
-            return params;
-        }
-
-        @Override
-        public String describe() {
-            return get(DESCRIPTION);
-        }
-
-        @Override
-        public void save(String path) throws IOException {}
+    @Test
+    public void testPipeline() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Builds a Pipeline of a SumModel (delta 10), a SumEstimator and a second SumModel
+        // (delta 30). For inputs 1, 2 and 3 the estimator is fitted on 11, 12 and 13, so it learns
+        // delta 36 and the final outputs are 77, 78 and 79.
+        SumModel addTen = new SumModel();
+        addTen.setModelData(tEnv.fromValues(10));
+        SumModel addThirty = new SumModel();
+        addThirty.setModelData(tEnv.fromValues(30));
+        List<Stage<?>> stages = Arrays.asList(addTen, new SumEstimator(), addThirty);
+        Estimator<?, ?> estimator = new Pipeline(stages);
+
+        List<Integer> input = Arrays.asList(1, 2, 3);
+        List<Integer> expectedOutput = Arrays.asList(77, 78, 79);
+        executeAndCheckOutput(estimator, input, expectedOutput);
+
+        // Round-trips the Pipeline through save()/load() and checks the restored copy as well.
+        Path tempDir = Files.createTempDirectory("PipelineTest");
+        String path = Paths.get(tempDir.toString(), "testPipeline").toString();
+        estimator.save(path);
+        Estimator<?, ?> loadedEstimator = Pipeline.load(path);
+        executeAndCheckOutput(loadedEstimator, input, expectedOutput);
     }
 }
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
new file mode 100644
index 0000000..9e03ddb
--- /dev/null
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
@@ -0,0 +1,375 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.FloatArrayParam;
+import org.apache.flink.ml.param.FloatParam;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.LongArrayParam;
+import org.apache.flink.ml.param.LongParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** Tests the behavior of Stage and WithParams. */
+public class StageTest {
+
+    // A WithParams subclass which has one parameter for each pre-defined parameter type.
+    private interface MyParams<T> extends WithParams<T> {
+        Param<Boolean> BOOLEAN_PARAM = new BooleanParam("booleanParam", "Description", false);
+
+        Param<Integer> INT_PARAM =
+                new IntParam("intParam", "Description", 1, ParamValidators.lt(100));
+
+        Param<Long> LONG_PARAM =
+                new LongParam("longParam", "Description", 2L, ParamValidators.lt(100));
+
+        Param<Float> FLOAT_PARAM =
+                new FloatParam("floatParam", "Description", 3.0f, ParamValidators.lt(100));
+
+        Param<Double> DOUBLE_PARAM =
+                new DoubleParam("doubleParam", "Description", 4.0, ParamValidators.lt(100));
+
+        Param<String> STRING_PARAM = new StringParam("stringParam", "Description", "5");
+
+        Param<Integer[]> INT_ARRAY_PARAM =
+                new IntArrayParam("intArrayParam", "Description", new Integer[] {6, 7});
+
+        Param<Long[]> LONG_ARRAY_PARAM =
+                new LongArrayParam(
+                        "longArrayParam",
+                        "Description",
+                        new Long[] {8L, 9L},
+                        ParamValidators.alwaysTrue());
+
+        Param<Float[]> FLOAT_ARRAY_PARAM =
+                new FloatArrayParam("floatArrayParam", "Description", new Float[] {10.0f, 11.0f});
+
+        Param<Double[]> DOUBLE_ARRAY_PARAM =
+                new DoubleArrayParam(
+                        "doubleArrayParam",
+                        "Description",
+                        new Double[] {12.0, 13.0},
+                        ParamValidators.alwaysTrue());
+
+        Param<String[]> STRING_ARRAY_PARAM =
+                new StringArrayParam("stringArrayParam", "Description", new String[] {"14", "15"});
+    }
+
+    /**
+     * A Stage subclass which inherits all parameters from MyParams and defines an extra parameter.
+     */
+    public static class MyStage implements Stage<MyStage>, MyParams<MyStage> {
+        private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+        public final Param<Integer> extraIntParam =
+                new IntParam("extraIntParam", "Description", 20, ParamValidators.alwaysTrue());
+
+        public final Param<Integer> paramWithNullDefault =
+                new IntParam(
+                        "paramWithNullDefault",
+                        "Must be explicitly set with a non-null value",
+                        null,
+                        ParamValidators.notNull());
+
+        public MyStage() {
+            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        }
+
+        @Override
+        public Map<Param<?>, Object> getParamMap() {
+            return paramMap;
+        }
+
+        @Override
+        public void save(String path) throws IOException {
+            ReadWriteUtils.saveMetadata(this, path);
+        }
+
+        public static MyStage load(String path) throws IOException {
+            return ReadWriteUtils.loadStageParam(path);
+        }
+    }
+
+    /** A Stage subclass without the static load() method. */
+    public static class MyStageWithoutLoad implements Stage<MyStage>, MyParams<MyStage> {
+        private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+        public MyStageWithoutLoad() {
+            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        }
+
+        @Override
+        public void save(String path) throws IOException {
+            ReadWriteUtils.saveMetadata(this, path);
+        }
+
+        @Override
+        public Map<Param<?>, Object> getParamMap() {
+            return paramMap;
+        }
+    }
+
+    // Asserts that m1 and m2 are equivalent.
+    private static void assertParamMapEquals(Map<Param<?>, Object> m1, Map<Param<?>, Object> m2) {
+        Assert.assertTrue(m1 != null && m2 != null);
+        Assert.assertEquals(m1.size(), m2.size());
+
+        for (Map.Entry<Param<?>, Object> entry : m1.entrySet()) {
+            Assert.assertTrue(m2.containsKey(entry.getKey()));
+            Object v1 = entry.getValue();
+            Object v2 = m2.get(entry.getKey());
+            if (v1 == null || v2 == null) {
+                Assert.assertTrue(v1 == null && v2 == null);
+            } else if (v1.getClass().isArray() && v2.getClass().isArray()) {
+                Assert.assertArrayEquals((Object[]) v1, (Object[]) v2);
+            } else {
+                Assert.assertEquals(v1, v2);
+            }
+        }
+    }
+
+    // Saves and loads the given stage. And verifies that the loaded stage has same parameter values
+    // as the original stage.
+    private static Stage<?> validateStageSaveLoad(
+            Stage<?> stage, Map<String, Object> paramOverrides) throws IOException {
+        for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
+            Param<?> param = stage.getParam(entry.getKey());
+            ReadWriteUtils.setStageParam(stage, param, entry.getValue());
+        }
+
+        String tempDir = Files.createTempDirectory("").toString();
+        String path = Paths.get(tempDir, "test").toString();
+        stage.save(path);
+        try {
+            stage.save(path);
+            Assert.fail("Expected IOException");
+        } catch (IOException e) {
+            // This is expected.
+        }
+
+        Stage<?> loadedStage = ReadWriteUtils.loadStage(path);
+        for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
+            Param<?> param = loadedStage.getParam(entry.getKey());
+            Assert.assertEquals(entry.getValue(), loadedStage.get(param));
+        }
+        assertParamMapEquals(stage.getParamMap(), loadedStage.getParamMap());
+        return loadedStage;
+    }
+
+    @Test
+    public void testParamSetValueWithName() {
+        MyStage stage = new MyStage();
+
+        Param<Integer> paramA = MyParams.INT_PARAM;
+        stage.set(paramA, 2);
+        Assert.assertEquals(2, (int) stage.get(paramA));
+
+        Param<Integer> paramB = stage.getParam("intParam");
+        stage.set(paramB, 3);
+        Assert.assertEquals(3, (int) stage.get(paramB));
+
+        Param<Integer> paramC = stage.getParam("extraIntParam");
+        stage.set(paramC, 50);
+        Assert.assertEquals(50, (int) stage.get(paramC));
+    }
+
+    @Test
+    public void testParamWithNullDefault() {
+        MyStage stage = new MyStage();
+        try {
+            stage.get(stage.paramWithNullDefault);
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (IllegalArgumentException e) {
+            Assert.assertTrue(e.getMessage().contains("should not be null"));
+        }
+
+        stage.set(stage.paramWithNullDefault, 3);
+        Assert.assertEquals(3, (int) stage.get(stage.paramWithNullDefault));
+    }
+
+    private static <T> void assertInvalidValue(Stage<?> stage, Param<T> param, T value) {
+        try {
+            stage.set(param, value);
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (IllegalArgumentException e) {
+            Assert.assertTrue(e.getMessage().contains("invalid value"));
+        }
+    }
+
+    private static <T> void assertInvalidClass(Stage<?> stage, Param<T> param, Object value) {
+        try {
+            stage.set(param, (T) value);
+            Assert.fail("Expected ClassCastException");
+        } catch (ClassCastException e) {
+            Assert.assertTrue(e.getMessage().contains("incompatible class"));
+        }
+    }
+
+    @Test
+    public void testParamSetInvalidValue() {
+        MyStage stage = new MyStage();
+        assertInvalidValue(stage, MyParams.INT_PARAM, 100);
+        assertInvalidValue(stage, MyParams.LONG_PARAM, 100L);
+        assertInvalidValue(stage, MyParams.FLOAT_PARAM, 100.0f);
+        assertInvalidValue(stage, MyParams.DOUBLE_PARAM, 100.0);
+        assertInvalidClass(stage, MyParams.INT_PARAM, "100");
+        assertInvalidClass(stage, MyParams.STRING_PARAM, 100);
+
+        Param<Integer> param = stage.getParam("stringParam");
+        assertInvalidClass(stage, param, 50);
+    }
+
+    @Test
+    public void testParamSetValidValue() {
+        MyStage stage = new MyStage();
+
+        stage.set(MyParams.BOOLEAN_PARAM, true);
+        Assert.assertEquals(true, stage.get(MyParams.BOOLEAN_PARAM));
+
+        stage.set(MyParams.INT_PARAM, 50);
+        Assert.assertEquals(50, (int) stage.get(MyParams.INT_PARAM));
+
+        stage.set(MyParams.LONG_PARAM, 50L);
+        Assert.assertEquals(50L, (long) stage.get(MyParams.LONG_PARAM));
+
+        stage.set(MyParams.FLOAT_PARAM, 50f);
+        Assert.assertEquals(50f, (float) stage.get(MyParams.FLOAT_PARAM), 0.0001);
+
+        stage.set(MyParams.DOUBLE_PARAM, 50.0);
+        Assert.assertEquals(50, (double) stage.get(MyParams.DOUBLE_PARAM), 0.0001);
+
+        stage.set(MyParams.STRING_PARAM, "50");
+        Assert.assertEquals("50", stage.get(MyParams.STRING_PARAM));
+
+        stage.set(MyParams.INT_ARRAY_PARAM, new Integer[] {50, 51});
+        Assert.assertArrayEquals(new Integer[] {50, 51}, stage.get(MyParams.INT_ARRAY_PARAM));
+
+        stage.set(MyParams.LONG_ARRAY_PARAM, new Long[] {50L, 51L});
+        Assert.assertArrayEquals(new Long[] {50L, 51L}, stage.get(MyParams.LONG_ARRAY_PARAM));
+
+        stage.set(MyParams.FLOAT_ARRAY_PARAM, new Float[] {50.0f, 51.0f});
+        Assert.assertArrayEquals(new Float[] {50.0f, 51.0f}, stage.get(MyParams.FLOAT_ARRAY_PARAM));
+
+        stage.set(MyParams.DOUBLE_ARRAY_PARAM, new Double[] {50.0, 51.0});
+        Assert.assertArrayEquals(new Double[] {50.0, 51.0}, stage.get(MyParams.DOUBLE_ARRAY_PARAM));
+
+        stage.set(MyParams.STRING_ARRAY_PARAM, new String[] {"50", "51"});
+        Assert.assertArrayEquals(new String[] {"50", "51"}, stage.get(MyParams.STRING_ARRAY_PARAM));
+    }
+
+    @Test
+    public void testStageSaveLoad() throws IOException {
+        MyStage stage = new MyStage();
+        stage.set(stage.paramWithNullDefault, 1);
+        Stage<?> loadedStage = validateStageSaveLoad(stage, Collections.emptyMap());
+        Assert.assertEquals(1, (int) loadedStage.get(MyParams.INT_PARAM));
+    }
+
+    @Test
+    public void testStageSaveLoadWithParamOverrides() throws IOException {
+        MyStage stage = new MyStage();
+        stage.set(stage.paramWithNullDefault, 1);
+        Stage<?> loadedStage =
+                validateStageSaveLoad(stage, Collections.singletonMap("intParam", 10));
+        Assert.assertEquals(10, (int) loadedStage.get(MyParams.INT_PARAM));
+    }
+
+    @Test
+    public void testStageLoadWithoutLoadMethod() throws IOException {
+        MyStageWithoutLoad stage = new MyStageWithoutLoad();
+        try {
+            validateStageSaveLoad(stage, Collections.emptyMap());
+            Assert.fail("Expected RuntimeException");
+        } catch (RuntimeException e) {
+            Assert.assertTrue(e.getMessage().contains("not implemented"));
+        }
+    }
+
+    @Test
+    public void testValidators() {
+        ParamValidator<Integer> gt = ParamValidators.gt(10);
+        Assert.assertFalse(gt.validate(null));
+        Assert.assertFalse(gt.validate(5));
+        Assert.assertFalse(gt.validate(10));
+        Assert.assertTrue(gt.validate(15));
+
+        ParamValidator<Integer> gtEq = ParamValidators.gtEq(10);
+        Assert.assertFalse(gtEq.validate(null));
+        Assert.assertFalse(gtEq.validate(5));
+        Assert.assertTrue(gtEq.validate(10));
+        Assert.assertTrue(gtEq.validate(15));
+
+        ParamValidator<Integer> lt = ParamValidators.lt(10);
+        Assert.assertFalse(lt.validate(null));
+        Assert.assertTrue(lt.validate(5));
+        Assert.assertFalse(lt.validate(10));
+        Assert.assertFalse(lt.validate(15));
+
+        ParamValidator<Integer> ltEq = ParamValidators.ltEq(10);
+        Assert.assertFalse(ltEq.validate(null));
+        Assert.assertTrue(ltEq.validate(5));
+        Assert.assertTrue(ltEq.validate(10));
+        Assert.assertFalse(ltEq.validate(15));
+
+        ParamValidator<Integer> inRangeInclusive = ParamValidators.inRange(5, 15);
+        Assert.assertFalse(inRangeInclusive.validate(null));
+        Assert.assertFalse(inRangeInclusive.validate(0));
+        Assert.assertTrue(inRangeInclusive.validate(5));
+        Assert.assertTrue(inRangeInclusive.validate(10));
+        Assert.assertTrue(inRangeInclusive.validate(15));
+        Assert.assertFalse(inRangeInclusive.validate(20));
+
+        ParamValidator<Integer> inRangeExclusive = ParamValidators.inRange(5, 15, false, false);
+        Assert.assertFalse(inRangeExclusive.validate(null));
+        Assert.assertFalse(inRangeExclusive.validate(0));
+        Assert.assertFalse(inRangeExclusive.validate(5));
+        Assert.assertTrue(inRangeExclusive.validate(10));
+        Assert.assertFalse(inRangeExclusive.validate(15));
+        Assert.assertFalse(inRangeExclusive.validate(20));
+
+        ParamValidator<Integer> inArray = ParamValidators.inArray(1, 2, 3);
+        Assert.assertFalse(inArray.validate(null));
+        Assert.assertTrue(inArray.validate(1));
+        Assert.assertFalse(inArray.validate(0));
+
+        ParamValidator<Integer> notNull = ParamValidators.notNull();
+        Assert.assertTrue(notNull.validate(5));
+        Assert.assertFalse(notNull.validate(null));
+    }
+}
diff --git a/pom.xml b/pom.xml
index 5eb1805..66530a0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -53,8 +53,6 @@ under the License.
 
   <modules>
     <module>flink-ml-api</module>
-    <module>flink-ml-lib</module>
-    <module>flink-ml-uber</module>
     <module>flink-ml-iteration</module>
     <module>flink-ml-tests</module>
   </modules>