You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by di...@apache.org on 2020/03/13 01:17:01 UTC

[flink] branch master updated: [FLINK-16250][python][ml] Add interfaces for PipelineStage and Pipeline (#11344)

This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f10a23  [FLINK-16250][python][ml] Add interfaces for PipelineStage and Pipeline (#11344)
6f10a23 is described below

commit 6f10a23e6bca741a9357a91457eb85a8879a40c2
Author: Hequn Cheng <ch...@gmail.com>
AuthorDate: Fri Mar 13 09:16:35 2020 +0800

    [FLINK-16250][python][ml] Add interfaces for PipelineStage and Pipeline (#11344)
---
 flink-ml-parent/flink-ml-lib/pom.xml               |  17 ++
 .../ml/pipeline/UserDefinedPipelineStages.java     |  54 ++++
 flink-python/pyflink/ml/api/__init__.py            |   6 +-
 flink-python/pyflink/ml/api/base.py                | 275 +++++++++++++++++++++
 flink-python/pyflink/ml/api/param/base.py          |   9 +-
 flink-python/pyflink/ml/lib/param/colname.py       |  17 ++
 flink-python/pyflink/ml/tests/test_pipeline.py     | 150 +++++++++++
 .../pyflink/ml/tests/test_pipeline_it_case.py      | 171 +++++++++++++
 .../pyflink/ml/tests/test_pipeline_stage.py        |  89 +++++++
 9 files changed, 783 insertions(+), 5 deletions(-)

diff --git a/flink-ml-parent/flink-ml-lib/pom.xml b/flink-ml-parent/flink-ml-lib/pom.xml
index df2930f..eb49b32 100644
--- a/flink-ml-parent/flink-ml-lib/pom.xml
+++ b/flink-ml-parent/flink-ml-lib/pom.xml
@@ -66,4 +66,21 @@ under the License.
 			<version>1.1.2</version>
 		</dependency>
 	</dependencies>
+
+	<build>
+		<plugins>
+			<!-- Because PyFlink uses it in tests -->
+			<plugin>
+				<groupId>org.apache.maven.plugins</groupId>
+				<artifactId>maven-jar-plugin</artifactId>
+				<executions>
+					<execution>
+						<goals>
+							<goal>test-jar</goal>
+						</goals>
+					</execution>
+				</executions>
+			</plugin>
+		</plugins>
+	</build>
 </project>
diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/pipeline/UserDefinedPipelineStages.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/pipeline/UserDefinedPipelineStages.java
new file mode 100644
index 0000000..7c5fd98
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/pipeline/UserDefinedPipelineStages.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.pipeline;
+
+import org.apache.flink.ml.api.core.Transformer;
+import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.params.shared.colname.HasSelectedCols;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+
+/**
+ * Util class for testing {@link org.apache.flink.ml.api.core.PipelineStage}.
+ */
+public class UserDefinedPipelineStages {
+
+	/**
+	 * A {@link Transformer} which is used to perform column selection.
+	 */
+	public static class SelectColumnTransformer implements
+		Transformer<SelectColumnTransformer>, HasSelectedCols<SelectColumnTransformer> {
+
+		private Params params;
+
+		public SelectColumnTransformer() {
+			this.params = new Params();
+		}
+
+		@Override
+		public Table transform(TableEnvironment tEnv, Table input) {
+			return input.select(String.join(", ", this.getSelectedCols()));
+		}
+
+		@Override
+		public Params getParams() {
+			return params;
+		}
+	}
+}
diff --git a/flink-python/pyflink/ml/api/__init__.py b/flink-python/pyflink/ml/api/__init__.py
index faca34a..90b3eb8 100644
--- a/flink-python/pyflink/ml/api/__init__.py
+++ b/flink-python/pyflink/ml/api/__init__.py
@@ -18,7 +18,11 @@
 
 from pyflink.ml.api.ml_environment import MLEnvironment
 from pyflink.ml.api.ml_environment_factory import MLEnvironmentFactory
+from pyflink.ml.api.base import Transformer, Estimator, Model, Pipeline, \
+    PipelineStage, JavaTransformer, JavaEstimator, JavaModel
+
 
 __all__ = [
-    "MLEnvironment", "MLEnvironmentFactory"
+    "MLEnvironment", "MLEnvironmentFactory", "Transformer", "Estimator",  "Model",
+    "Pipeline", "PipelineStage", "JavaTransformer", "JavaEstimator", "JavaModel"
 ]
diff --git a/flink-python/pyflink/ml/api/base.py b/flink-python/pyflink/ml/api/base.py
new file mode 100644
index 0000000..8888df3
--- /dev/null
+++ b/flink-python/pyflink/ml/api/base.py
@@ -0,0 +1,275 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import re
+
+from abc import ABCMeta, abstractmethod
+
+from pyflink.table.table_environment import TableEnvironment
+from pyflink.table.table import Table
+from pyflink.ml.api.param import WithParams, Params
+from py4j.java_gateway import get_field
+
+
+class PipelineStage(WithParams):
+    """
+    Base class for a stage in a pipeline. The interface is only a concept, and does not have any
+    actual functionality. Its subclasses must be either Estimator or Transformer. No other classes
+    should inherit this interface directly.
+
+    Each pipeline stage has parameters, and requires a public empty constructor for
+    restoration in Pipeline.
+    """
+
+    def __init__(self, params=None):
+        if params is None:
+            self._params = Params()
+        else:
+            self._params = params
+
+    def get_params(self) -> Params:
+        return self._params
+
+    def _convert_params_to_java(self, j_pipeline_stage):
+        for param in self._params._param_map:
+            java_param = self._make_java_param(j_pipeline_stage, param)
+            java_value = self._make_java_value(self._params._param_map[param])
+            j_pipeline_stage.set(java_param, java_value)
+
+    @staticmethod
+    def _make_java_param(j_pipeline_stage, param):
+        # camel case to upper snake case, e.g. selectedCols -> SELECTED_COLS
+        name = re.sub(r'(?<!^)(?=[A-Z])', '_', param.name).upper()
+        return get_field(j_pipeline_stage, name)
+
+    @staticmethod
+    def _make_java_value(obj):
+        """ Convert Python object into Java """
+        if isinstance(obj, list):
+            obj = [PipelineStage._make_java_value(x) for x in obj]
+        return obj
+
+    def to_json(self) -> str:
+        return self.get_params().to_json()
+
+    def load_json(self, json: str) -> None:
+        self.get_params().load_json(json)
+
+
+class Transformer(PipelineStage):
+    """
+    A transformer is a PipelineStage that transforms an input Table to a result Table.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def transform(self, table_env: TableEnvironment, table: Table) -> Table:
+        """
+        Applies the transformer on the input table, and returns the result table.
+
+        :param table_env: the table environment to which the input table is bound.
+        :param table: the table to be transformed
+        :returns: the transformed table
+        """
+        raise NotImplementedError()
+
+
+class JavaTransformer(Transformer):
+    """
+    Base class for Transformers that wrap Java implementations. Subclasses should
+    ensure they have the transformer Java object available as j_obj.
+    """
+
+    def __init__(self, j_obj):
+        super().__init__()
+        self._j_obj = j_obj
+
+    def transform(self, table_env: TableEnvironment, table: Table) -> Table:
+        """
+        Applies the transformer on the input table, and returns the result table.
+
+        :param table_env: the table environment to which the input table is bound.
+        :param table: the table to be transformed
+        :returns: the transformed table
+        """
+        self._convert_params_to_java(self._j_obj)
+        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))
+
+
+class Model(Transformer):
+    """
+    Abstract class for models that are fitted by estimators.
+
+    A model is an ordinary Transformer except how it is created. While ordinary transformers
+    are defined by specifying the parameters directly, a model is usually generated by an Estimator
+    when Estimator.fit(table_env, table) is invoked.
+    """
+
+    __metaclass__ = ABCMeta
+
+
+class JavaModel(JavaTransformer, Model):
+    """
+    Base class for Models that wrap Java implementations.
+    Subclasses should ensure they have the model Java object available as j_obj.
+    """
+
+
+class Estimator(PipelineStage):
+    """
+    Estimators are PipelineStages responsible for training and generating machine learning models.
+
+    The implementations are expected to take an input table as training samples and generate a
+    Model which fits these samples.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def fit(self, table_env: TableEnvironment, table: Table) -> Model:
+        """
+        Train and produce a Model which fits the records in the given Table.
+
+        :param table_env: the table environment to which the input table is bound.
+        :param table: the table with records to train the Model.
+        :returns: a model trained to fit on the given Table.
+        """
+        raise NotImplementedError()
+
+
+class JavaEstimator(Estimator):
+    """
+    Base class for Estimators that wrap Java implementations.
+    Subclasses should ensure they have the estimator Java object available as j_obj.
+    """
+
+    def __init__(self, j_obj):
+        super().__init__()
+        self._j_obj = j_obj
+
+    def fit(self, table_env: TableEnvironment, table: Table) -> JavaModel:
+        """
+        Train and produce a Model which fits the records in the given Table.
+
+        :param table_env: the table environment to which the input table is bound.
+        :param table: the table with records to train the Model.
+        :returns: a model trained to fit on the given Table.
+        """
+        self._convert_params_to_java(self._j_obj)
+        return JavaModel(self._j_obj.fit(table_env._j_tenv, table._j_table))
+
+
+class Pipeline(Estimator, Model, Transformer):
+    """
+    A pipeline is a linear workflow which chains Estimators and Transformers to
+    execute an algorithm.
+
+    A pipeline itself can either act as an Estimator or a Transformer, depending on the stages it
+    includes. More specifically:
+
+
+    If a Pipeline has an Estimator, one needs to call `Pipeline.fit(TableEnvironment, Table)`
+    before using the pipeline as a Transformer. In this case the Pipeline is an Estimator and
+    can produce a Pipeline as a `Model`.
+
+    If a Pipeline has no Estimator, it is a Transformer and can be applied to a Table directly.
+    In this case, `Pipeline.fit(TableEnvironment, Table)` returns a Pipeline with the same stages.
+
+
+    In addition, a pipeline can also be used as a PipelineStage in another pipeline, just like an
+    ordinary Estimator or Transformer as described above.
+    """
+
+    def __init__(self, stages=None, pipeline_json=None):
+        super().__init__()
+        self.stages = []
+        self.last_estimator_index = -1
+        if stages is not None:
+            for stage in stages:
+                self.append_stage(stage)
+        if pipeline_json is not None:
+            self.load_json(pipeline_json)
+
+    def need_fit(self):
+        return self.last_estimator_index >= 0
+
+    @staticmethod
+    def _is_stage_need_fit(stage):
+        return (isinstance(stage, Pipeline) and stage.need_fit()) or \
+               ((not isinstance(stage, Pipeline)) and isinstance(stage, Estimator))
+
+    def get_stages(self) -> tuple:
+        # make it immutable by changing to tuple
+        return tuple(self.stages)
+
+    def append_stage(self, stage: PipelineStage) -> 'Pipeline':
+        if self._is_stage_need_fit(stage):
+            self.last_estimator_index = len(self.stages)
+        elif not isinstance(stage, Transformer):
+            raise RuntimeError("All PipelineStages should be Estimator or Transformer!")
+        self.stages.append(stage)
+        return self
+
+    def fit(self, t_env: TableEnvironment, input: Table) -> 'Pipeline':
+        """
+        Train the pipeline to fit on the records in the given Table.
+
+        :param t_env: the table environment to which the input table is bound.
+        :param input: the table with records to train the Pipeline.
+        :returns: a pipeline with same stages as this Pipeline except all Estimators \
+        replaced with their corresponding Models.
+        """
+        transform_stages = []
+        for i in range(0, len(self.stages)):
+            s = self.stages[i]
+            if i <= self.last_estimator_index:
+                need_fit = self._is_stage_need_fit(s)
+                if need_fit:
+                    t = s.fit(t_env, input)
+                else:
+                    t = s
+                transform_stages.append(t)
+                input = t.transform(t_env, input)
+            else:
+                transform_stages.append(s)
+        return Pipeline(transform_stages)
+
+    def transform(self, t_env: TableEnvironment, input: Table) -> Table:
+        """
+        Generate a result table by applying all the stages in this pipeline to
+        the input table in order.
+
+        :param t_env: the table environment to which the input table is bound.
+        :param input: the table to be transformed.
+        :returns: a result table with all the stages applied to the input tables in order.
+        """
+        if self.need_fit():
+            raise RuntimeError("Pipeline contains Estimator, need to fit first.")
+        for s in self.stages:
+            input = s.transform(t_env, input)
+        return input
+
+    def to_json(self) -> str:
+        import jsonpickle
+        return str(jsonpickle.encode(self, keys=True))
+
+    def load_json(self, json: str) -> None:
+        import jsonpickle
+        pipeline = jsonpickle.decode(json, keys=True)
+        for stage in pipeline.get_stages():
+            self.append_stage(stage)
diff --git a/flink-python/pyflink/ml/api/param/base.py b/flink-python/pyflink/ml/api/param/base.py
index b784bc7..5188f2c 100644
--- a/flink-python/pyflink/ml/api/param/base.py
+++ b/flink-python/pyflink/ml/api/param/base.py
@@ -164,17 +164,16 @@ class Params(Generic[V]):
         import jsonpickle
         return str(jsonpickle.encode(self._param_map, keys=True))
 
-    def load_json(self, json: str) -> 'Params':
+    def load_json(self, json: str) -> None:
         """
         Restores the parameters from the given json. The parameters should be exactly
         the same with the one who was serialized to the input json after the restoration.
 
         :param json: the json String to restore from.
-        :return: the Params.
+        :return: None.
         """
         import jsonpickle
         self._param_map.update(jsonpickle.decode(json, keys=True))
-        return self
 
     @staticmethod
     def from_json(json) -> 'Params':
@@ -184,7 +183,9 @@ class Params(Generic[V]):
         :param json: the json string to load.
         :return: the `Params` loaded from the json string.
         """
-        return Params().load_json(json)
+        ret = Params()
+        ret.load_json(json)
+        return ret
 
     def merge(self, other_params: 'Params') -> 'Params':
         """
diff --git a/flink-python/pyflink/ml/lib/param/colname.py b/flink-python/pyflink/ml/lib/param/colname.py
index 034cd4c..581551b 100644
--- a/flink-python/pyflink/ml/lib/param/colname.py
+++ b/flink-python/pyflink/ml/lib/param/colname.py
@@ -53,3 +53,20 @@ class HasOutputCol(WithParams):
 
     def get_output_col(self) -> str:
         return super().get(self.output_col)
+
+
+class HasPredictionCol(WithParams):
+    """
+    An interface for classes with a parameter specifying the column name of the prediction.
+    """
+    prediction_col = ParamInfo(
+        "predictionCol",
+        "Column name of prediction.",
+        is_optional=False,
+        type_converter=TypeConverters.to_string)
+
+    def set_prediction_col(self, v: str) -> 'HasPredictionCol':
+        return super().set(self.prediction_col, v)
+
+    def get_prediction_col(self) -> str:
+        return super().get(self.prediction_col)
diff --git a/flink-python/pyflink/ml/tests/test_pipeline.py b/flink-python/pyflink/ml/tests/test_pipeline.py
new file mode 100644
index 0000000..31c3068
--- /dev/null
+++ b/flink-python/pyflink/ml/tests/test_pipeline.py
@@ -0,0 +1,150 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import unittest
+from pyflink.ml.api import JavaTransformer, Transformer, Estimator, Model, \
+    Pipeline, JavaEstimator, JavaModel
+from pyflink.ml.api.param.base import WithParams, ParamInfo, TypeConverters
+from pyflink import keyword
+
+
+class PipelineTest(unittest.TestCase):
+
+    @staticmethod
+    def describe_pipeline(pipeline):
+        res = [stage.get_desc() for stage in pipeline.get_stages()]
+        return "_".join(res)
+
+    def test_construct_pipeline(self):
+        pipeline1 = Pipeline()
+        pipeline1.append_stage(MockTransformer(self_desc="a1"))
+        pipeline1.append_stage(MockJavaTransformer(self_desc="ja1"))
+
+        pipeline2 = Pipeline()
+        pipeline2.append_stage(MockTransformer(self_desc="a2"))
+        pipeline2.append_stage(MockJavaTransformer(self_desc="ja2"))
+
+        pipeline3 = Pipeline(pipeline1.get_stages(), pipeline2.to_json())
+        self.assertEqual("a1_ja1_a2_ja2", PipelineTest.describe_pipeline(pipeline3))
+
+    def test_pipeline_behavior(self):
+        pipeline = Pipeline()
+        pipeline.append_stage(MockTransformer(self_desc="a"))
+        pipeline.append_stage(MockJavaTransformer(self_desc="ja"))
+        pipeline.append_stage(MockEstimator(self_desc="b"))
+        pipeline.append_stage(MockJavaEstimator(self_desc="jb"))
+        pipeline.append_stage(MockEstimator(self_desc="c"))
+        pipeline.append_stage(MockTransformer(self_desc="d"))
+        self.assertEqual("a_ja_b_jb_c_d", PipelineTest.describe_pipeline(pipeline))
+
+        pipeline_model = pipeline.fit(None, None)
+        self.assertEqual("a_ja_mb_mjb_mc_d", PipelineTest.describe_pipeline(pipeline_model))
+
+    def test_pipeline_restore(self):
+        pipeline = Pipeline()
+        pipeline.append_stage(MockTransformer(self_desc="a"))
+        pipeline.append_stage(MockJavaTransformer(self_desc="ja"))
+        pipeline.append_stage(MockEstimator(self_desc="b"))
+        pipeline.append_stage(MockJavaEstimator(self_desc="jb"))
+        pipeline.append_stage(MockEstimator(self_desc="c"))
+        pipeline.append_stage(MockTransformer(self_desc="d"))
+
+        pipeline_new = Pipeline()
+        pipeline_new.load_json(pipeline.to_json())
+        self.assertEqual("a_ja_b_jb_c_d", PipelineTest.describe_pipeline(pipeline_new))
+
+        pipeline_model = pipeline_new.fit(None, None)
+        self.assertEqual("a_ja_mb_mjb_mc_d", PipelineTest.describe_pipeline(pipeline_model))
+
+
+class SelfDescribe(WithParams):
+    self_desc = ParamInfo("selfDesc", "selfDesc", type_converter=TypeConverters.to_string)
+
+    def set_desc(self, v):
+        return super().set(self.self_desc, v)
+
+    def get_desc(self):
+        return super().get(self.self_desc)
+
+
+class MockTransformer(Transformer, SelfDescribe):
+    @keyword
+    def __init__(self, *, self_desc=None):
+        super().__init__()
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    def transform(self, table_env, table):
+        return table
+
+
+class MockEstimator(Estimator, SelfDescribe):
+    @keyword
+    def __init__(self, *, self_desc=None):
+        super().__init__()
+        self._self_desc = self_desc
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    def fit(self, table_env, table):
+        return MockModel(self_desc="m" + self._self_desc)
+
+
+class MockModel(Model, SelfDescribe):
+    @keyword
+    def __init__(self, *, self_desc=None):
+        super().__init__()
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    def transform(self, table_env, table):
+        return table
+
+
+class MockJavaTransformer(JavaTransformer, SelfDescribe):
+    @keyword
+    def __init__(self, *, self_desc=None):
+        super().__init__(None)
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    def transform(self, table_env, table):
+        return table
+
+
+class MockJavaEstimator(JavaEstimator, SelfDescribe):
+    @keyword
+    def __init__(self, *, self_desc=None):
+        super().__init__(None)
+        self._self_desc = self_desc
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    def fit(self, table_env, table):
+        return MockJavaModel(self_desc="m" + self._self_desc)
+
+
+class MockJavaModel(JavaModel, SelfDescribe):
+    @keyword
+    def __init__(self, *, self_desc=None):
+        super().__init__(None)
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    def transform(self, table_env, table):
+        return table
diff --git a/flink-python/pyflink/ml/tests/test_pipeline_it_case.py b/flink-python/pyflink/ml/tests/test_pipeline_it_case.py
new file mode 100644
index 0000000..655d20c
--- /dev/null
+++ b/flink-python/pyflink/ml/tests/test_pipeline_it_case.py
@@ -0,0 +1,171 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from pyflink.table.types import DataTypes
+from pyflink.testing.test_case_utils import MLTestCase
+
+from pyflink.ml.api import JavaTransformer, Transformer, Estimator, Model, \
+    MLEnvironmentFactory, Pipeline
+from pyflink.ml.api.param import WithParams, ParamInfo, TypeConverters
+from pyflink.ml.lib.param.colname import HasSelectedCols,\
+    HasPredictionCol, HasOutputCol
+from pyflink import keyword
+from pyflink.testing import source_sink_utils
+from pyflink.java_gateway import get_gateway
+
+
+class HasVectorCol(WithParams):
+    """
+    Trait for parameter vectorColName.
+    """
+    vector_col = ParamInfo(
+        "vectorCol",
+        "Name of a vector column",
+        is_optional=False,
+        type_converter=TypeConverters.to_string)
+
+    def set_vector_col(self, v: str) -> 'HasVectorCol':
+        return super().set(self.vector_col, v)
+
+    def get_vector_col(self) -> str:
+        return super().get(self.vector_col)
+
+
+class WrapperTransformer(JavaTransformer, HasSelectedCols):
+    """
+    A Transformer which wraps a Java Transformer.
+    """
+    @keyword
+    def __init__(self, *, selected_cols=None):
+        _j_obj = get_gateway().jvm.org.apache.flink.ml.pipeline.\
+            UserDefinedPipelineStages.SelectColumnTransformer()
+        super().__init__(_j_obj)
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+
+class PythonAddTransformer(Transformer, HasSelectedCols, HasOutputCol):
+    """
+    A Transformer which is implemented with Python. Outputs a column
+    containing the sum of the selected columns.
+    """
+    @keyword
+    def __init__(self, *, selected_cols=None, output_col=None):
+        super().__init__()
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    def transform(self, table_env, table):
+        input_columns = self.get_selected_cols()
+        expr = "+".join(input_columns)
+        expr = expr + " as " + self.get_output_col()
+        return table.add_columns(expr)
+
+
+class PythonEstimator(Estimator, HasVectorCol, HasPredictionCol):
+
+    def __init__(self):
+        super().__init__()
+
+    def fit(self, table_env, table):
+        return PythonModel(
+            table_env,
+            table.select("max(features) as max_sum"),
+            self.get_prediction_col())
+
+
+class PythonModel(Model):
+
+    def __init__(self, table_env, model_data_table, output_col_name):
+        self._model_data_table = model_data_table
+        self._output_col_name = output_col_name
+        self.max_sum = 0
+        self.load_model(table_env)
+
+    def load_model(self, table_env):
+        """
+        Train the model to get the max_sum value which is used to predict data.
+        """
+        table_sink = source_sink_utils.TestRetractSink(["max_sum"], [DataTypes.BIGINT()])
+        table_env.register_table_sink("Model_Results", table_sink)
+        self._model_data_table.insert_into("Model_Results")
+        table_env.execute("load model")
+        actual = source_sink_utils.results()
+        self.max_sum = actual.apply(0)
+
+    def transform(self, table_env, table):
+        """
+        Use max_sum to predict input. Return true if input value is bigger than max_sum
+        """
+        return table\
+            .add_columns("features > {} as {}".format(self.max_sum, self._output_col_name))\
+            .select("{}".format(self._output_col_name))
+
+
+class PythonPipelineTest(MLTestCase):
+
+    def test_java_transformer(self):
+        t_env = MLEnvironmentFactory().get_default().get_stream_table_environment()
+
+        table_sink = source_sink_utils.TestAppendSink(
+            ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
+        t_env.register_table_sink("TransformerResults", table_sink)
+
+        source_table = t_env.from_elements([(1, 2, 3, 4), (4, 3, 2, 1)], ['a', 'b', 'c', 'd'])
+        transformer = WrapperTransformer(selected_cols=["a", "b"])
+        transformer\
+            .transform(t_env, source_table)\
+            .insert_into("TransformerResults")
+
+        # execute
+        t_env.execute('JavaPipelineITCase')
+        actual = source_sink_utils.results()
+        self.assert_equals(actual, ["1,2", "4,3"])
+
+    def test_pipeline(self):
+        t_env = MLEnvironmentFactory().get_default().get_stream_table_environment()
+        train_table = t_env.from_elements(
+            [(1, 2), (1, 4), (1, 0), (10, 2), (10, 4), (10, 0)], ['a', 'b'])
+        serving_table = t_env.from_elements([(0, 0), (12, 3)], ['a', 'b'])
+
+        table_sink = source_sink_utils.TestAppendSink(
+            ['predict_result'],
+            [DataTypes.BOOLEAN()])
+        t_env.register_table_sink("PredictResults", table_sink)
+
+        # transformer, output features column which is the sum of a and b.
+        transformer = PythonAddTransformer(selected_cols=["a", "b"], output_col="features")
+
+        # estimator
+        estimator = PythonEstimator()\
+            .set_vector_col("features")\
+            .set_prediction_col("predict_result")
+
+        # pipeline
+        pipeline = Pipeline().append_stage(transformer).append_stage(estimator)
+        pipeline\
+            .fit(t_env, train_table)\
+            .transform(t_env, serving_table)\
+            .insert_into('PredictResults')
+        # execute
+        t_env.execute('PipelineITCase')
+
+        actual = source_sink_utils.results()
+        # the first input is false since 0 + 0 is smaller than the max_sum 14.
+        # the second input is true since 12 + 3 is bigger than the max_sum 14.
+        self.assert_equals(actual, ["false", "true"])
diff --git a/flink-python/pyflink/ml/tests/test_pipeline_stage.py b/flink-python/pyflink/ml/tests/test_pipeline_stage.py
new file mode 100644
index 0000000..cce53f2
--- /dev/null
+++ b/flink-python/pyflink/ml/tests/test_pipeline_stage.py
@@ -0,0 +1,89 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from abc import ABC, abstractmethod
+from pyflink.ml.api import Transformer, Estimator, MLEnvironmentFactory
+import unittest
+
+
+class PipelineStageTestBase(ABC):
+    """
+    The base class for testing the base implementation of pipeline stages, i.e. Estimators
+    and Transformers. This class is package private because we do not expect extension outside
+    of the package.
+    """
+
+    @abstractmethod
+    def create_pipeline_stage(self):
+        pass
+
+
+class TransformerBaseTest(PipelineStageTestBase, unittest.TestCase):
+    """
+    Test for TransformerBase.
+    """
+    def test_fit_table(self):
+        id = MLEnvironmentFactory.get_new_ml_environment_id()
+        env = MLEnvironmentFactory.get(id)
+        table = env.get_stream_table_environment().from_elements([(1, 2, 3)])
+        transformer = self.create_pipeline_stage()
+        transformer.transform(env.get_stream_table_environment(), table)
+        self.assertTrue(transformer.transformed)
+
+    def create_pipeline_stage(self):
+        return self.FakeTransFormer()
+
+    class FakeTransFormer(Transformer):
+        """
+        This fake transformer simply records whether the transform method is invoked.
+        """
+
+        def __init__(self):
+            self.transformed = False
+
+        def transform(self, table_env, table):
+            self.transformed = True
+            return table
+
+
+class EstimatorBaseTest(PipelineStageTestBase, unittest.TestCase):
+    """
+    Test for EstimatorBase.
+    """
+    def test_fit_table(self):
+        id = MLEnvironmentFactory.get_new_ml_environment_id()
+        env = MLEnvironmentFactory.get(id)
+        table = env.get_stream_table_environment().from_elements([(1, 2, 3)])
+        estimator = self.create_pipeline_stage()
+        estimator.fit(env.get_stream_table_environment(), table)
+        self.assertTrue(estimator.fitted)
+
+    def create_pipeline_stage(self):
+        return self.FakeEstimator()
+
+    class FakeEstimator(Estimator):
+        """
+        This fake estimator simply records whether the fit method is invoked.
+        """
+
+        def __init__(self):
+            self.fitted = False
+
+        def fit(self, table_env, table):
+            self.fitted = True
+            return None