You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by di...@apache.org on 2020/03/13 01:17:01 UTC
[flink] branch master updated: [FLINK-16250][python][ml] Add
interfaces for PipelineStage and Pipeline (#11344)
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 6f10a23 [FLINK-16250][python][ml] Add interfaces for PipelineStage and Pipeline (#11344)
6f10a23 is described below
commit 6f10a23e6bca741a9357a91457eb85a8879a40c2
Author: Hequn Cheng <ch...@gmail.com>
AuthorDate: Fri Mar 13 09:16:35 2020 +0800
[FLINK-16250][python][ml] Add interfaces for PipelineStage and Pipeline (#11344)
---
flink-ml-parent/flink-ml-lib/pom.xml | 17 ++
.../ml/pipeline/UserDefinedPipelineStages.java | 54 ++++
flink-python/pyflink/ml/api/__init__.py | 6 +-
flink-python/pyflink/ml/api/base.py | 275 +++++++++++++++++++++
flink-python/pyflink/ml/api/param/base.py | 9 +-
flink-python/pyflink/ml/lib/param/colname.py | 17 ++
flink-python/pyflink/ml/tests/test_pipeline.py | 150 +++++++++++
.../pyflink/ml/tests/test_pipeline_it_case.py | 171 +++++++++++++
.../pyflink/ml/tests/test_pipeline_stage.py | 89 +++++++
9 files changed, 783 insertions(+), 5 deletions(-)
diff --git a/flink-ml-parent/flink-ml-lib/pom.xml b/flink-ml-parent/flink-ml-lib/pom.xml
index df2930f..eb49b32 100644
--- a/flink-ml-parent/flink-ml-lib/pom.xml
+++ b/flink-ml-parent/flink-ml-lib/pom.xml
@@ -66,4 +66,21 @@ under the License.
<version>1.1.2</version>
</dependency>
</dependencies>
+
+ <build>
+ <plugins>
+ <!-- Because PyFlink uses it in tests -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <executions>
+ <execution>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
</project>
diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/pipeline/UserDefinedPipelineStages.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/pipeline/UserDefinedPipelineStages.java
new file mode 100644
index 0000000..7c5fd98
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/pipeline/UserDefinedPipelineStages.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.pipeline;
+
+import org.apache.flink.ml.api.core.Transformer;
+import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.params.shared.colname.HasSelectedCols;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+
+/**
+ * Util class for testing {@link org.apache.flink.ml.api.core.PipelineStage}.
+ */
+public class UserDefinedPipelineStages {
+
+    /**
+     * A {@link Transformer} that projects the input table onto the columns configured
+     * through {@link HasSelectedCols}.
+     */
+    public static class SelectColumnTransformer implements
+        Transformer<SelectColumnTransformer>, HasSelectedCols<SelectColumnTransformer> {
+
+        /** Parameter container backing this stage; created once per instance. */
+        private Params params;
+
+        public SelectColumnTransformer() {
+            this.params = new Params();
+        }
+
+        @Override
+        public Table transform(TableEnvironment tEnv, Table input) {
+            // Build the comma separated field list expected by Table#select(String).
+            final String selectedFields = String.join(", ", this.getSelectedCols());
+            return input.select(selectedFields);
+        }
+
+        @Override
+        public Params getParams() {
+            return params;
+        }
+    }
+}
diff --git a/flink-python/pyflink/ml/api/__init__.py b/flink-python/pyflink/ml/api/__init__.py
index faca34a..90b3eb8 100644
--- a/flink-python/pyflink/ml/api/__init__.py
+++ b/flink-python/pyflink/ml/api/__init__.py
@@ -18,7 +18,11 @@
from pyflink.ml.api.ml_environment import MLEnvironment
from pyflink.ml.api.ml_environment_factory import MLEnvironmentFactory
+from pyflink.ml.api.base import Transformer, Estimator, Model, Pipeline, \
+ PipelineStage, JavaTransformer, JavaEstimator, JavaModel
+
__all__ = [
- "MLEnvironment", "MLEnvironmentFactory"
+ "MLEnvironment", "MLEnvironmentFactory", "Transformer", "Estimator", "Model",
+ "Pipeline", "PipelineStage", "JavaTransformer", "JavaEstimator", "JavaModel"
]
diff --git a/flink-python/pyflink/ml/api/base.py b/flink-python/pyflink/ml/api/base.py
new file mode 100644
index 0000000..8888df3
--- /dev/null
+++ b/flink-python/pyflink/ml/api/base.py
@@ -0,0 +1,275 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import re
+
+from abc import ABCMeta, abstractmethod
+
+from pyflink.table.table_environment import TableEnvironment
+from pyflink.table.table import Table
+from pyflink.ml.api.param import WithParams, Params
+from py4j.java_gateway import get_field
+
+
+class PipelineStage(WithParams):
+ """
+ Base class for a stage in a pipeline. The interface is only a concept, and does not have any
+ actual functionality. Its subclasses must be either Estimator or Transformer. No other classes
+ should inherit this interface directly.
+
+ Each pipeline stage is with parameters, and requires a public empty constructor for
+ restoration in Pipeline.
+ """
+
+ def __init__(self, params=None):
+ if params is None:
+ self._params = Params()
+ else:
+ self._params = params
+
+ def get_params(self) -> Params:
+ return self._params
+
+ def _convert_params_to_java(self, j_pipeline_stage):
+ for param in self._params._param_map:
+ java_param = self._make_java_param(j_pipeline_stage, param)
+ java_value = self._make_java_value(self._params._param_map[param])
+ j_pipeline_stage.set(java_param, java_value)
+
+ @staticmethod
+ def _make_java_param(j_pipeline_stage, param):
+ # camel case to snake case
+ name = re.sub(r'(?<!^)(?=[A-Z])', '_', param.name).upper()
+ return get_field(j_pipeline_stage, name)
+
+ @staticmethod
+ def _make_java_value(obj):
+ """ Convert Python object into Java """
+ if isinstance(obj, list):
+ obj = [PipelineStage._make_java_value(x) for x in obj]
+ return obj
+
+ def to_json(self) -> str:
+ return self.get_params().to_json()
+
+ def load_json(self, json: str) -> None:
+ self.get_params().load_json(json)
+
+
+class Transformer(PipelineStage):
+ """
+ A transformer is a PipelineStage that transforms an input Table to a result Table.
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def transform(self, table_env: TableEnvironment, table: Table) -> Table:
+ """
+ Applies the transformer on the input table, and returns the result table.
+
+ :param table_env: the table environment to which the input table is bound.
+ :param table: the table to be transformed
+ :returns: the transformed table
+ """
+ raise NotImplementedError()
+
+
+class JavaTransformer(Transformer):
+ """
+ Base class for Transformer that wrap Java implementations. Subclasses should
+ ensure they have the transformer Java object available as j_obj.
+ """
+
+ def __init__(self, j_obj):
+ super().__init__()
+ self._j_obj = j_obj
+
+ def transform(self, table_env: TableEnvironment, table: Table) -> Table:
+ """
+ Applies the transformer on the input table, and returns the result table.
+
+ :param table_env: the table environment to which the input table is bound.
+ :param table: the table to be transformed
+ :returns: the transformed table
+ """
+ self._convert_params_to_java(self._j_obj)
+ return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))
+
+
+class Model(Transformer):
+ """
+ Abstract class for models that are fitted by estimators.
+
+ A model is an ordinary Transformer except how it is created. While ordinary transformers
+ are defined by specifying the parameters directly, a model is usually generated by an Estimator
+ when Estimator.fit(table_env, table) is invoked.
+ """
+
+ __metaclass__ = ABCMeta
+
+
+class JavaModel(JavaTransformer, Model):
+ """
+ Base class for JavaTransformer that wrap Java implementations.
+ Subclasses should ensure they have the model Java object available as j_obj.
+ """
+
+
+class Estimator(PipelineStage):
+ """
+ Estimators are PipelineStages responsible for training and generating machine learning models.
+
+ The implementations are expected to take an input table as training samples and generate a
+ Model which fits these samples.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def fit(self, table_env: TableEnvironment, table: Table) -> Model:
+ """
+ Train and produce a Model which fits the records in the given Table.
+
+ :param table_env: the table environment to which the input table is bound.
+ :param table: the table with records to train the Model.
+ :returns: a model trained to fit on the given Table.
+ """
+ raise NotImplementedError()
+
+
+class JavaEstimator(Estimator):
+ """
+ Base class for Estimator that wrap Java implementations.
+ Subclasses should ensure they have the estimator Java object available as j_obj.
+ """
+
+ def __init__(self, j_obj):
+ super().__init__()
+ self._j_obj = j_obj
+
+ def fit(self, table_env: TableEnvironment, table: Table) -> JavaModel:
+ """
+ Train and produce a Model which fits the records in the given Table.
+
+ :param table_env: the table environment to which the input table is bound.
+ :param table: the table with records to train the Model.
+ :returns: a model trained to fit on the given Table.
+ """
+ self._convert_params_to_java(self._j_obj)
+ return JavaModel(self._j_obj.fit(table_env._j_tenv, table._j_table))
+
+
+class Pipeline(Estimator, Model, Transformer):
+ """
+ A pipeline is a linear workflow which chains Estimators and Transformers to
+ execute an algorithm.
+
+ A pipeline itself can either act as an Estimator or a Transformer, depending on the stages it
+ includes. More specifically:
+
+
+ If a Pipeline has an Estimator, one needs to call `Pipeline.fit(TableEnvironment, Table)`
+ before use the pipeline as a Transformer. In this case the Pipeline is an Estimator and
+ can produce a Pipeline as a `Model`.
+
+ If a Pipeline has noEstimator, it is a Transformer and can be applied to a Table directly.
+ In this case, `Pipeline#fit(TableEnvironment, Table)` will simply return the pipeline itself.
+
+
+ In addition, a pipeline can also be used as a PipelineStage in another pipeline, just like an
+ ordinaryEstimator or Transformer as describe above.
+ """
+
+ def __init__(self, stages=None, pipeline_json=None):
+ super().__init__()
+ self.stages = []
+ self.last_estimator_index = -1
+ if stages is not None:
+ for stage in stages:
+ self.append_stage(stage)
+ if pipeline_json is not None:
+ self.load_json(pipeline_json)
+
+ def need_fit(self):
+ return self.last_estimator_index >= 0
+
+ @staticmethod
+ def _is_stage_need_fit(stage):
+ return (isinstance(stage, Pipeline) and stage.need_fit()) or \
+ ((not isinstance(stage, Pipeline)) and isinstance(stage, Estimator))
+
+ def get_stages(self) -> tuple:
+ # make it immutable by changing to tuple
+ return tuple(self.stages)
+
+ def append_stage(self, stage: PipelineStage) -> 'Pipeline':
+ if self._is_stage_need_fit(stage):
+ self.last_estimator_index = len(self.stages)
+ elif not isinstance(stage, Transformer):
+ raise RuntimeError("All PipelineStages should be Estimator or Transformer!")
+ self.stages.append(stage)
+ return self
+
+ def fit(self, t_env: TableEnvironment, input: Table) -> 'Pipeline':
+ """
+ Train the pipeline to fit on the records in the given Table.
+
+ :param t_env: the table environment to which the input table is bound.
+ :param input: the table with records to train the Pipeline.
+ :returns: a pipeline with same stages as this Pipeline except all Estimators \
+ replaced with their corresponding Models.
+ """
+ transform_stages = []
+ for i in range(0, len(self.stages)):
+ s = self.stages[i]
+ if i <= self.last_estimator_index:
+ need_fit = self._is_stage_need_fit(s)
+ if need_fit:
+ t = s.fit(t_env, input)
+ else:
+ t = s
+ transform_stages.append(t)
+ input = t.transform(t_env, input)
+ else:
+ transform_stages.append(s)
+ return Pipeline(transform_stages)
+
+ def transform(self, t_env: TableEnvironment, input: Table) -> Table:
+ """
+ Generate a result table by applying all the stages in this pipeline to
+ the input table in order.
+
+ :param t_env: the table environment to which the input table is bound.
+ :param input: the table to be transformed.
+ :returns: a result table with all the stages applied to the input tables in order.
+ """
+ if self.need_fit():
+ raise RuntimeError("Pipeline contains Estimator, need to fit first.")
+ for s in self.stages:
+ input = s.transform(t_env, input)
+ return input
+
+ def to_json(self) -> str:
+ import jsonpickle
+ return str(jsonpickle.encode(self, keys=True))
+
+ def load_json(self, json: str) -> None:
+ import jsonpickle
+ pipeline = jsonpickle.decode(json, keys=True)
+ for stage in pipeline.get_stages():
+ self.append_stage(stage)
diff --git a/flink-python/pyflink/ml/api/param/base.py b/flink-python/pyflink/ml/api/param/base.py
index b784bc7..5188f2c 100644
--- a/flink-python/pyflink/ml/api/param/base.py
+++ b/flink-python/pyflink/ml/api/param/base.py
@@ -164,17 +164,16 @@ class Params(Generic[V]):
import jsonpickle
return str(jsonpickle.encode(self._param_map, keys=True))
- def load_json(self, json: str) -> 'Params':
+ def load_json(self, json: str) -> None:
"""
Restores the parameters from the given json. The parameters should be exactly
the same with the one who was serialized to the input json after the restoration.
:param json: the json String to restore from.
- :return: the Params.
+ :return: None.
"""
import jsonpickle
self._param_map.update(jsonpickle.decode(json, keys=True))
- return self
@staticmethod
def from_json(json) -> 'Params':
@@ -184,7 +183,9 @@ class Params(Generic[V]):
:param json: the json string to load.
:return: the `Params` loaded from the json string.
"""
- return Params().load_json(json)
+ ret = Params()
+ ret.load_json(json)
+ return ret
def merge(self, other_params: 'Params') -> 'Params':
"""
diff --git a/flink-python/pyflink/ml/lib/param/colname.py b/flink-python/pyflink/ml/lib/param/colname.py
index 034cd4c..581551b 100644
--- a/flink-python/pyflink/ml/lib/param/colname.py
+++ b/flink-python/pyflink/ml/lib/param/colname.py
@@ -53,3 +53,20 @@ class HasOutputCol(WithParams):
def get_output_col(self) -> str:
return super().get(self.output_col)
+
+
+class HasPredictionCol(WithParams):
+ """
+ An interface for classes with a parameter specifying the column name of the prediction.
+ """
+ prediction_col = ParamInfo(
+ "predictionCol",
+ "Column name of prediction.",
+ is_optional=False,
+ type_converter=TypeConverters.to_string)
+
+ def set_prediction_col(self, v: str) -> 'HasPredictionCol':
+ return super().set(self.prediction_col, v)
+
+ def get_prediction_col(self) -> str:
+ return super().get(self.prediction_col)
diff --git a/flink-python/pyflink/ml/tests/test_pipeline.py b/flink-python/pyflink/ml/tests/test_pipeline.py
new file mode 100644
index 0000000..31c3068
--- /dev/null
+++ b/flink-python/pyflink/ml/tests/test_pipeline.py
@@ -0,0 +1,150 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import unittest
+from pyflink.ml.api import JavaTransformer, Transformer, Estimator, Model, \
+ Pipeline, JavaEstimator, JavaModel
+from pyflink.ml.api.param.base import WithParams, ParamInfo, TypeConverters
+from pyflink import keyword
+
+
+class PipelineTest(unittest.TestCase):
+
+ @staticmethod
+ def describe_pipeline(pipeline):
+ res = [stage.get_desc() for stage in pipeline.get_stages()]
+ return "_".join(res)
+
+ def test_construct_pipeline(self):
+ pipeline1 = Pipeline()
+ pipeline1.append_stage(MockTransformer(self_desc="a1"))
+ pipeline1.append_stage(MockJavaTransformer(self_desc="ja1"))
+
+ pipeline2 = Pipeline()
+ pipeline2.append_stage(MockTransformer(self_desc="a2"))
+ pipeline2.append_stage(MockJavaTransformer(self_desc="ja2"))
+
+ pipeline3 = Pipeline(pipeline1.get_stages(), pipeline2.to_json())
+ self.assertEqual("a1_ja1_a2_ja2", PipelineTest.describe_pipeline(pipeline3))
+
+ def test_pipeline_behavior(self):
+ pipeline = Pipeline()
+ pipeline.append_stage(MockTransformer(self_desc="a"))
+ pipeline.append_stage(MockJavaTransformer(self_desc="ja"))
+ pipeline.append_stage(MockEstimator(self_desc="b"))
+ pipeline.append_stage(MockJavaEstimator(self_desc="jb"))
+ pipeline.append_stage(MockEstimator(self_desc="c"))
+ pipeline.append_stage(MockTransformer(self_desc="d"))
+ self.assertEqual("a_ja_b_jb_c_d", PipelineTest.describe_pipeline(pipeline))
+
+ pipeline_model = pipeline.fit(None, None)
+ self.assertEqual("a_ja_mb_mjb_mc_d", PipelineTest.describe_pipeline(pipeline_model))
+
+ def test_pipeline_restore(self):
+ pipeline = Pipeline()
+ pipeline.append_stage(MockTransformer(self_desc="a"))
+ pipeline.append_stage(MockJavaTransformer(self_desc="ja"))
+ pipeline.append_stage(MockEstimator(self_desc="b"))
+ pipeline.append_stage(MockJavaEstimator(self_desc="jb"))
+ pipeline.append_stage(MockEstimator(self_desc="c"))
+ pipeline.append_stage(MockTransformer(self_desc="d"))
+
+ pipeline_new = Pipeline()
+ pipeline_new.load_json(pipeline.to_json())
+ self.assertEqual("a_ja_b_jb_c_d", PipelineTest.describe_pipeline(pipeline_new))
+
+ pipeline_model = pipeline_new.fit(None, None)
+ self.assertEqual("a_ja_mb_mjb_mc_d", PipelineTest.describe_pipeline(pipeline_model))
+
+
+class SelfDescribe(WithParams):
+ self_desc = ParamInfo("selfDesc", "selfDesc", type_converter=TypeConverters.to_string)
+
+ def set_desc(self, v):
+ return super().set(self.self_desc, v)
+
+ def get_desc(self):
+ return super().get(self.self_desc)
+
+
+class MockTransformer(Transformer, SelfDescribe):
+ @keyword
+ def __init__(self, *, self_desc=None):
+ super().__init__()
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+ def transform(self, table_env, table):
+ return table
+
+
+class MockEstimator(Estimator, SelfDescribe):
+ @keyword
+ def __init__(self, *, self_desc=None):
+ super().__init__()
+ self._self_desc = self_desc
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+ def fit(self, table_env, table):
+ return MockModel(self_desc="m" + self._self_desc)
+
+
+class MockModel(Model, SelfDescribe):
+ @keyword
+ def __init__(self, *, self_desc=None):
+ super().__init__()
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+ def transform(self, table_env, table):
+ return table
+
+
+class MockJavaTransformer(JavaTransformer, SelfDescribe):
+ @keyword
+ def __init__(self, *, self_desc=None):
+ super().__init__(None)
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+ def transform(self, table_env, table):
+ return table
+
+
+class MockJavaEstimator(JavaEstimator, SelfDescribe):
+ @keyword
+ def __init__(self, *, self_desc=None):
+ super().__init__(None)
+ self._self_desc = self_desc
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+ def fit(self, table_env, table):
+ return MockJavaModel(self_desc="m" + self._self_desc)
+
+
+class MockJavaModel(JavaModel, SelfDescribe):
+ @keyword
+ def __init__(self, *, self_desc=None):
+ super().__init__(None)
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+ def transform(self, table_env, table):
+ return table
diff --git a/flink-python/pyflink/ml/tests/test_pipeline_it_case.py b/flink-python/pyflink/ml/tests/test_pipeline_it_case.py
new file mode 100644
index 0000000..655d20c
--- /dev/null
+++ b/flink-python/pyflink/ml/tests/test_pipeline_it_case.py
@@ -0,0 +1,171 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from pyflink.table.types import DataTypes
+from pyflink.testing.test_case_utils import MLTestCase
+
+from pyflink.ml.api import JavaTransformer, Transformer, Estimator, Model, \
+ MLEnvironmentFactory, Pipeline
+from pyflink.ml.api.param import WithParams, ParamInfo, TypeConverters
+from pyflink.ml.lib.param.colname import HasSelectedCols,\
+ HasPredictionCol, HasOutputCol
+from pyflink import keyword
+from pyflink.testing import source_sink_utils
+from pyflink.java_gateway import get_gateway
+
+
+class HasVectorCol(WithParams):
+ """
+ Trait for parameter vectorColName.
+ """
+ vector_col = ParamInfo(
+ "vectorCol",
+ "Name of a vector column",
+ is_optional=False,
+ type_converter=TypeConverters.to_string)
+
+ def set_vector_col(self, v: str) -> 'HasVectorCol':
+ return super().set(self.vector_col, v)
+
+ def get_vector_col(self) -> str:
+ return super().get(self.vector_col)
+
+
+class WrapperTransformer(JavaTransformer, HasSelectedCols):
+ """
+ A Transformer wrappers Java Transformer.
+ """
+ @keyword
+ def __init__(self, *, selected_cols=None):
+ _j_obj = get_gateway().jvm.org.apache.flink.ml.pipeline.\
+ UserDefinedPipelineStages.SelectColumnTransformer()
+ super().__init__(_j_obj)
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+
+class PythonAddTransformer(Transformer, HasSelectedCols, HasOutputCol):
+ """
+ A Transformer which is implemented with Python. Output a column
+ contains the sum of all columns.
+ """
+ @keyword
+ def __init__(self, *, selected_cols=None, output_col=None):
+ super().__init__()
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+ def transform(self, table_env, table):
+ input_columns = self.get_selected_cols()
+ expr = "+".join(input_columns)
+ expr = expr + " as " + self.get_output_col()
+ return table.add_columns(expr)
+
+
+class PythonEstimator(Estimator, HasVectorCol, HasPredictionCol):
+
+ def __init__(self):
+ super().__init__()
+
+ def fit(self, table_env, table):
+ return PythonModel(
+ table_env,
+ table.select("max(features) as max_sum"),
+ self.get_prediction_col())
+
+
+class PythonModel(Model):
+
+ def __init__(self, table_env, model_data_table, output_col_name):
+ self._model_data_table = model_data_table
+ self._output_col_name = output_col_name
+ self.max_sum = 0
+ self.load_model(table_env)
+
+ def load_model(self, table_env):
+ """
+ Train the model to get the max_sum value which is used to predict data.
+ """
+ table_sink = source_sink_utils.TestRetractSink(["max_sum"], [DataTypes.BIGINT()])
+ table_env.register_table_sink("Model_Results", table_sink)
+ self._model_data_table.insert_into("Model_Results")
+ table_env.execute("load model")
+ actual = source_sink_utils.results()
+ self.max_sum = actual.apply(0)
+
+ def transform(self, table_env, table):
+ """
+ Use max_sum to predict input. Return turn if input value is bigger than max_sum
+ """
+ return table\
+ .add_columns("features > {} as {}".format(self.max_sum, self._output_col_name))\
+ .select("{}".format(self._output_col_name))
+
+
+class PythonPipelineTest(MLTestCase):
+
+ def test_java_transformer(self):
+ t_env = MLEnvironmentFactory().get_default().get_stream_table_environment()
+
+ table_sink = source_sink_utils.TestAppendSink(
+ ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
+ t_env.register_table_sink("TransformerResults", table_sink)
+
+ source_table = t_env.from_elements([(1, 2, 3, 4), (4, 3, 2, 1)], ['a', 'b', 'c', 'd'])
+ transformer = WrapperTransformer(selected_cols=["a", "b"])
+ transformer\
+ .transform(t_env, source_table)\
+ .insert_into("TransformerResults")
+
+ # execute
+ t_env.execute('JavaPipelineITCase')
+ actual = source_sink_utils.results()
+ self.assert_equals(actual, ["1,2", "4,3"])
+
+ def test_pipeline(self):
+ t_env = MLEnvironmentFactory().get_default().get_stream_table_environment()
+ train_table = t_env.from_elements(
+ [(1, 2), (1, 4), (1, 0), (10, 2), (10, 4), (10, 0)], ['a', 'b'])
+ serving_table = t_env.from_elements([(0, 0), (12, 3)], ['a', 'b'])
+
+ table_sink = source_sink_utils.TestAppendSink(
+ ['predict_result'],
+ [DataTypes.BOOLEAN()])
+ t_env.register_table_sink("PredictResults", table_sink)
+
+ # transformer, output features column which is the sum of a and b.
+ transformer = PythonAddTransformer(selected_cols=["a", "b"], output_col="features")
+
+ # estimator
+ estimator = PythonEstimator()\
+ .set_vector_col("features")\
+ .set_prediction_col("predict_result")
+
+ # pipeline
+ pipeline = Pipeline().append_stage(transformer).append_stage(estimator)
+ pipeline\
+ .fit(t_env, train_table)\
+ .transform(t_env, serving_table)\
+ .insert_into('PredictResults')
+ # execute
+ t_env.execute('PipelineITCase')
+
+ actual = source_sink_utils.results()
+ # the first input is false since 0 + 0 is smaller than the max_sum 14.
+ # the second input is true since 12 + 3 is bigger than the max_sum 14.
+ self.assert_equals(actual, ["false", "true"])
diff --git a/flink-python/pyflink/ml/tests/test_pipeline_stage.py b/flink-python/pyflink/ml/tests/test_pipeline_stage.py
new file mode 100644
index 0000000..cce53f2
--- /dev/null
+++ b/flink-python/pyflink/ml/tests/test_pipeline_stage.py
@@ -0,0 +1,89 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from abc import ABC, abstractmethod
+from pyflink.ml.api import Transformer, Estimator, MLEnvironmentFactory
+import unittest
+
+
+class PipelineStageTestBase(ABC):
+ """
+ The base class for testing the base implementation of pipeline stages, i.e. Estimators
+ and Transformers. This class is package private because we do not expect extension outside
+ of the package.
+ """
+
+ @abstractmethod
+ def create_pipeline_stage(self):
+ pass
+
+
+class TransformerBaseTest(PipelineStageTestBase, unittest.TestCase):
+ """
+ Test for TransformerBase.
+ """
+ def test_fit_table(self):
+ id = MLEnvironmentFactory.get_new_ml_environment_id()
+ env = MLEnvironmentFactory.get(id)
+ table = env.get_stream_table_environment().from_elements([(1, 2, 3)])
+ transformer = self.create_pipeline_stage()
+ transformer.transform(env.get_stream_table_environment(), table)
+ self.assertTrue(transformer.transformed)
+
+ def create_pipeline_stage(self):
+ return self.FakeTransFormer()
+
+ class FakeTransFormer(Transformer):
+ """
+ This fake transformer simply record which transform method is invoked.
+ """
+
+ def __init__(self):
+ self.transformed = False
+
+ def transform(self, table_env, table):
+ self.transformed = True
+ return table
+
+
+class EstimatorBaseTest(PipelineStageTestBase, unittest.TestCase):
+ """
+ Test for EstimatorBase.
+ """
+ def test_fit_table(self):
+ id = MLEnvironmentFactory.get_new_ml_environment_id()
+ env = MLEnvironmentFactory.get(id)
+ table = env.get_stream_table_environment().from_elements([(1, 2, 3)])
+ estimator = self.create_pipeline_stage()
+ estimator.fit(env.get_stream_table_environment(), table)
+ self.assertTrue(estimator.fitted)
+
+ def create_pipeline_stage(self):
+ return self.FakeEstimator()
+
+ class FakeEstimator(Estimator):
+ """
+ This fake estimator simply record which fit method is invoked.
+ """
+
+ def __init__(self):
+ self.fitted = False
+
+ def fit(self, table_env, table):
+ self.fitted = True
+ return None