You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sh...@apache.org on 2019/05/24 10:23:23 UTC
[flink] branch master updated: [FLINK-12473][ml] Add ML pipeline
and MLlib interface
This is an automated email from the ASF dual-hosted git repository.
shaoxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 3050957 [FLINK-12473][ml] Add ML pipeline and MLlib interface
3050957 is described below
commit 305095743ffe0bc39f76c1bda983da7d0df9e003
Author: Gen Luo <lu...@gmail.com>
AuthorDate: Fri May 24 17:05:23 2019 +0800
[FLINK-12473][ml] Add ML pipeline and MLlib interface
This closes #8402
---
flink-ml-parent/flink-ml-api/pom.xml | 45 ++++
.../org/apache/flink/ml/api/core/Estimator.java | 47 ++++
.../java/org/apache/flink/ml/api/core/Model.java | 39 +++
.../org/apache/flink/ml/api/core/Pipeline.java | 266 +++++++++++++++++++++
.../apache/flink/ml/api/core/PipelineStage.java | 56 +++++
.../org/apache/flink/ml/api/core/Transformer.java | 42 ++++
.../apache/flink/ml/api/misc/param/ParamInfo.java | 130 ++++++++++
.../flink/ml/api/misc/param/ParamInfoFactory.java | 129 ++++++++++
.../flink/ml/api/misc/param/ParamValidator.java | 39 +++
.../org/apache/flink/ml/api/misc/param/Params.java | 151 ++++++++++++
.../apache/flink/ml/api/misc/param/WithParams.java | 60 +++++
.../flink/ml/util/param/ExtractParamInfosUtil.java | 73 ++++++
.../org/apache/flink/ml/api/core/PipelineTest.java | 181 ++++++++++++++
.../org/apache/flink/ml/api/misc/ParamsTest.java | 72 ++++++
.../ml/util/param/ExtractParamInfosUtilTest.java | 104 ++++++++
flink-ml-parent/pom.xml | 39 +++
pom.xml | 1 +
17 files changed, 1474 insertions(+)
diff --git a/flink-ml-parent/flink-ml-api/pom.xml b/flink-ml-parent/flink-ml-api/pom.xml
new file mode 100644
index 0000000..f77f068
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/pom.xml
@@ -0,0 +1,45 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <parent>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-ml-parent</artifactId>
+ <version>1.9-SNAPSHOT</version>
+ </parent>
+
+ <artifactId>flink-ml-api</artifactId>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-table-api-java</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-shaded-jackson</artifactId>
+ <version>${jackson.version}-${flink.shaded.version}</version>
+ </dependency>
+ </dependencies>
+</project>
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
new file mode 100644
index 0000000..8e31d94
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+
+/**
+ * Estimators are {@link PipelineStage}s responsible for training and generating machine learning
+ * models.
+ *
+ * <p>The implementations are expected to take an input table as training samples and generate a
+ * {@link Model} which fits these samples.
+ *
+ * @param <E> class type of the Estimator implementation itself, used by {@link
+ * org.apache.flink.ml.api.misc.param.WithParams}.
+ * @param <M> class type of the {@link Model} this Estimator produces.
+ */
+@PublicEvolving
+public interface Estimator<E extends Estimator<E, M>, M extends Model<M>> extends PipelineStage<E> {
+
+ /**
+ * Train and produce a {@link Model} which fits the records in the given {@link Table}.
+ *
+ * @param tEnv the table environment to which the input table is bound.
+ * @param input the table with records to train the Model.
+ * @return a model trained to fit on the given Table.
+ */
+ M fit(TableEnvironment tEnv, Table input);
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
new file mode 100644
index 0000000..b52b6a9
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.Table;
+
+/**
+ * A Model is a {@link Transformer} that differs from other transformers only in how it comes into
+ * existence. Plain transformers are configured by setting their parameters directly, whereas a
+ * model is usually the result of calling
+ * {@link Estimator#fit(org.apache.flink.table.api.TableEnvironment, Table)}.
+ *
+ * <p>Model is kept as a type separate from {@link Transformer} so that model-specific logic can be
+ * supported, for example linking a Model back to the {@link Estimator} that produced it.
+ *
+ * @param <M> the concrete Model type itself (self type), used by {@link
+ *            org.apache.flink.ml.api.misc.param.WithParams}
+ */
+@PublicEvolving
+public interface Model<M extends Model<M>> extends Transformer<M> {
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
new file mode 100644
index 0000000..c8326c4
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.util.InstantiationUtil;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A pipeline is a linear workflow which chains {@link Estimator}s and {@link Transformer}s to
+ * execute an algorithm.
+ *
+ * <p>A pipeline itself can either act as an Estimator or a Transformer, depending on the stages it
+ * includes. More specifically:
+ * <ul>
+ * <li>
+ * If a Pipeline has an {@link Estimator}, one needs to call {@link Pipeline#fit(TableEnvironment,
+ * Table)} before using the pipeline as a {@link Transformer}. In this case the Pipeline is an
+ * {@link Estimator} and can produce a Pipeline as a {@link Model}.
+ * </li>
+ * <li>
+ * If a Pipeline has no {@link Estimator}, it is a {@link Transformer} and can be applied to a Table
+ * directly. In this case, {@link Pipeline#fit(TableEnvironment, Table)} simply returns a new
+ * pipeline holding the same stage instances.
+ * </li>
+ * </ul>
+ *
+ * <p>In addition, a pipeline can also be used as a {@link PipelineStage} in another pipeline, just
+ * like an ordinary {@link Estimator} or {@link Transformer} as described above.
+ */
+@PublicEvolving
+public final class Pipeline implements Estimator<Pipeline, Pipeline>, Transformer<Pipeline>,
+		Model<Pipeline> {
+	private static final long serialVersionUID = 1L;
+
+	/**
+	 * Shared JSON mapper. It is never reconfigured after construction, so sharing it between calls
+	 * (and threads) is safe per Jackson's ObjectMapper contract.
+	 */
+	private static final ObjectMapper MAPPER = new ObjectMapper();
+
+	private final List<PipelineStage> stages = new ArrayList<>();
+	private final Params params = new Params();
+
+	/** Index in {@link #stages} of the last stage that needs fitting, or -1 if there is none. */
+	private int lastEstimatorIndex = -1;
+
+	/** Creates an empty pipeline. */
+	public Pipeline() {
+	}
+
+	/**
+	 * Creates a pipeline from its JSON representation, as produced by {@link #toJson()}.
+	 *
+	 * @param pipelineJson the JSON representation of the pipeline
+	 */
+	public Pipeline(String pipelineJson) {
+		this.loadJson(pipelineJson);
+	}
+
+	/**
+	 * Creates a pipeline containing the given stages, appended in order via
+	 * {@link #appendStage(PipelineStage)}.
+	 *
+	 * @param stages the stages to append; each must be an Estimator or a Transformer
+	 */
+	public Pipeline(List<PipelineStage> stages) {
+		for (PipelineStage s : stages) {
+			appendStage(s);
+		}
+	}
+
+	/**
+	 * Returns whether the stage must be fitted: either a nested Pipeline that itself needs fitting,
+	 * or any other stage that is an Estimator. A nested Pipeline always implements Estimator, so it
+	 * is judged by {@link #needFit()} instead of by its type.
+	 */
+	private static boolean isStageNeedFit(PipelineStage stage) {
+		if (stage instanceof Pipeline) {
+			return ((Pipeline) stage).needFit();
+		}
+		return stage instanceof Estimator;
+	}
+
+	/**
+	 * Appends a PipelineStage to the tail of this pipeline. Pipeline is editable only via this
+	 * method. The PipelineStage must be Estimator, Transformer, Model or Pipeline.
+	 *
+	 * @param stage the stage to be appended
+	 * @return this pipeline, to allow chained calls
+	 * @throws NullPointerException if {@code stage} is null
+	 * @throws IllegalArgumentException if {@code stage} is neither an Estimator nor a Transformer
+	 */
+	public Pipeline appendStage(PipelineStage stage) {
+		if (stage == null) {
+			throw new NullPointerException("PipelineStage to append must not be null");
+		}
+		if (isStageNeedFit(stage)) {
+			lastEstimatorIndex = stages.size();
+		} else if (!(stage instanceof Transformer)) {
+			throw new IllegalArgumentException(
+				"All PipelineStages should be Estimator or Transformer, got:" +
+					stage.getClass().getSimpleName());
+		}
+		stages.add(stage);
+		return this;
+	}
+
+	/**
+	 * Returns a list of all stages in this pipeline in order, the list is immutable.
+	 *
+	 * @return an immutable list of all stages in this pipeline in order.
+	 */
+	public List<PipelineStage> getStages() {
+		return Collections.unmodifiableList(stages);
+	}
+
+	/**
+	 * Check whether the pipeline acts as an {@link Estimator} or not. When the return value is
+	 * true, that means this pipeline contains an {@link Estimator} and thus users must invoke
+	 * {@link #fit(TableEnvironment, Table)} before they can use this pipeline as a {@link
+	 * Transformer}. Otherwise, the pipeline can be used as a {@link Transformer} directly.
+	 *
+	 * @return {@code true} if this pipeline has an Estimator, {@code false} otherwise
+	 */
+	public boolean needFit() {
+		return this.getIndexOfLastEstimator() >= 0;
+	}
+
+	@Override
+	public Params getParams() {
+		return params;
+	}
+
+	// Index of the last Estimator (or Pipeline that needs fit) in stages; -1 means there is none.
+	private int getIndexOfLastEstimator() {
+		return lastEstimatorIndex;
+	}
+
+	/**
+	 * Train the pipeline to fit on the records in the given {@link Table}.
+	 *
+	 * <p>This method goes through all the {@link PipelineStage}s in order and does the following
+	 * on each stage until the last {@link Estimator} (inclusive).
+	 *
+	 * <ul>
+	 * <li>
+	 * If a stage is an {@link Estimator}, invoke {@link Estimator#fit(TableEnvironment, Table)}
+	 * with the input table to generate a {@link Model}, transform the input table with the
+	 * generated {@link Model} to get a result table, then pass the result table to the next stage
+	 * as input.
+	 * </li>
+	 * <li>
+	 * If a stage is a {@link Transformer}, invoke {@link Transformer#transform(TableEnvironment,
+	 * Table)} on the input table to get a result table, and pass the result table to the next stage
+	 * as input.
+	 * </li>
+	 * </ul>
+	 *
+	 * <p>After all the {@link Estimator}s are trained to fit their input tables, a new
+	 * pipeline will be created with the same stages in this pipeline, except that all the
+	 * Estimators in the new pipeline are replaced with their corresponding Models generated in the
+	 * above process.
+	 *
+	 * <p>If there is no {@link Estimator} in the pipeline, the method returns a new pipeline holding
+	 * the same stage instances as this one.
+	 *
+	 * @param tEnv the table environment to which the input table is bound.
+	 * @param input the table with records to train the Pipeline.
+	 * @return a pipeline with same stages as this Pipeline except all Estimators replaced with
+	 * their corresponding Models.
+	 */
+	@Override
+	public Pipeline fit(TableEnvironment tEnv, Table input) {
+		List<PipelineStage> transformStages = new ArrayList<>(stages.size());
+		int lastEstimatorIdx = getIndexOfLastEstimator();
+		Table current = input;
+		for (int i = 0; i < stages.size(); i++) {
+			PipelineStage s = stages.get(i);
+			if (i > lastEstimatorIdx) {
+				// Stages after the last Estimator are Transformers already; keep them unchanged.
+				transformStages.add(s);
+				continue;
+			}
+			Transformer t;
+			if (isStageNeedFit(s)) {
+				t = ((Estimator) s).fit(tEnv, current);
+			} else {
+				// stage is Transformer, guaranteed in appendStage() method
+				t = (Transformer) s;
+			}
+			transformStages.add(t);
+			current = t.transform(tEnv, current);
+		}
+		return new Pipeline(transformStages);
+	}
+
+	/**
+	 * Generate a result table by applying all the stages in this pipeline to the input table in
+	 * order.
+	 *
+	 * @param tEnv the table environment to which the input table is bound.
+	 * @param input the table to be transformed
+	 * @return a result table with all the stages applied to the input tables in order.
+	 * @throws IllegalStateException if the pipeline still contains an Estimator (see
+	 *                               {@link #needFit()})
+	 */
+	@Override
+	public Table transform(TableEnvironment tEnv, Table input) {
+		if (needFit()) {
+			throw new IllegalStateException("Pipeline contains Estimator, need to fit first.");
+		}
+		Table current = input;
+		for (PipelineStage s : stages) {
+			// Safe: without an Estimator, appendStage() only accepted Transformers.
+			current = ((Transformer) s).transform(tEnv, current);
+		}
+		return current;
+	}
+
+	/**
+	 * Serializes this pipeline as a JSON array with one object per stage. Each object holds the
+	 * stage's class name ({@code stageClassName}) and the stage's own JSON ({@code stageJson}).
+	 *
+	 * @return the JSON representation of this pipeline
+	 * @throws RuntimeException if serialization fails
+	 */
+	@Override
+	public String toJson() {
+		List<Map<String, String>> stageJsons = new ArrayList<>();
+		for (PipelineStage s : getStages()) {
+			Map<String, String> stageMap = new HashMap<>();
+			// Binary class name, i.e. the form Class.forName expects when restoring.
+			stageMap.put("stageClassName", s.getClass().getName());
+			stageMap.put("stageJson", s.toJson());
+			stageJsons.add(stageMap);
+		}
+
+		try {
+			return MAPPER.writeValueAsString(stageJsons);
+		} catch (JsonProcessingException e) {
+			throw new RuntimeException("Failed to serialize pipeline", e);
+		}
+	}
+
+	/**
+	 * Restores stages from JSON produced by {@link #toJson()} and appends them to this pipeline.
+	 *
+	 * @param json the JSON representation of a pipeline
+	 * @throws RuntimeException if the JSON cannot be parsed or a stage cannot be restored
+	 */
+	@Override
+	@SuppressWarnings("unchecked")
+	public void loadJson(String json) {
+		List<Map<String, String>> stageJsons;
+		try {
+			// Raw List.class: Jackson yields List<Map<String, Object>>; entries written by toJson()
+			// only contain String values, hence the unchecked conversion.
+			stageJsons = MAPPER.readValue(json, List.class);
+		} catch (IOException e) {
+			throw new RuntimeException("Failed to deserialize pipeline json:" + json, e);
+		}
+		for (Map<String, String> stageMap : stageJsons) {
+			appendStage(restoreInnerStage(stageMap));
+		}
+	}
+
+	/**
+	 * Instantiates the stage class named in {@code stageMap} through its public no-argument
+	 * constructor, then loads the stage's JSON into it.
+	 */
+	private PipelineStage<?> restoreInnerStage(Map<String, String> stageMap) {
+		String className = stageMap.get("stageClassName");
+		if (className == null) {
+			throw new RuntimeException("Missing stageClassName in pipeline stage json: " + stageMap);
+		}
+		Class<?> clz;
+		try {
+			// Load without initializing, so no static initializer of an unverified class runs.
+			clz = Class.forName(className, false, Pipeline.class.getClassLoader());
+		} catch (ClassNotFoundException e) {
+			throw new RuntimeException("PipelineStage class " + className + " not exists", e);
+		}
+		// The class name comes from the JSON input; never instantiate anything but a PipelineStage.
+		if (!PipelineStage.class.isAssignableFrom(clz)) {
+			throw new RuntimeException("Class " + className + " is not a PipelineStage");
+		}
+		InstantiationUtil.checkForInstantiation(clz);
+
+		PipelineStage<?> s;
+		try {
+			s = (PipelineStage<?>) clz.getConstructor().newInstance();
+		} catch (Exception e) {
+			throw new RuntimeException("Class is instantiable but failed to new an instance", e);
+		}
+
+		String stageJson = stageMap.get("stageJson");
+		s.loadJson(stageJson);
+		return s;
+	}
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
new file mode 100644
index 0000000..86bf0d3
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.misc.param.ParamInfo;
+import org.apache.flink.ml.api.misc.param.WithParams;
+import org.apache.flink.ml.util.param.ExtractParamInfosUtil;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Base interface for a stage in a {@link Pipeline}. The interface is mostly a concept: apart from
+ * (de)serializing the stage's parameters as JSON, it has no functionality of its own. Its
+ * sub-interfaces must be either {@link Estimator} or {@link Transformer}; no other class should
+ * implement this interface directly.
+ *
+ * <p>The interface is public because public APIs expose it, e.g. {@link Pipeline#getStages()},
+ * {@link Pipeline#appendStage(PipelineStage)} and {@link Pipeline#Pipeline(java.util.List)}.
+ *
+ * <p>Every pipeline stage carries parameters and requires a public no-argument constructor so it
+ * can be restored by {@link Pipeline}.
+ *
+ * @param <T> the concrete PipelineStage type itself (self type), used by {@link WithParams}
+ * @see WithParams
+ */
+@PublicEvolving
+public interface PipelineStage<T extends PipelineStage<T>> extends WithParams<T>, Serializable {
+
+	/**
+	 * Serializes the parameters of this stage to JSON.
+	 *
+	 * @return the result of {@code getParams().toJson()}
+	 */
+	default String toJson() {
+		return getParams().toJson();
+	}
+
+	/**
+	 * Loads parameter values from JSON into this stage's {@code Params}.
+	 *
+	 * <p>A map from parameter name to value class is built from the {@link ParamInfo}s returned by
+	 * {@link ExtractParamInfosUtil#extractParamInfos} for this stage. The map is passed to
+	 * {@code Params#loadJson} together with the JSON.
+	 *
+	 * @param json the JSON representation of the parameters, e.g. as produced by {@link #toJson()}
+	 */
+	default void loadJson(String json) {
+		List<ParamInfo> paramInfos = ExtractParamInfosUtil.extractParamInfos(this);
+		Map<String, Class<?>> classMap = new HashMap<>();
+		for (ParamInfo<?> info : paramInfos) {
+			classMap.put(info.getName(), info.getValueClass());
+		}
+		getParams().loadJson(json, classMap);
+	}
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
new file mode 100644
index 0000000..9435a61
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+
+/**
+ * A Transformer is a {@link PipelineStage} that maps an input {@link Table} to a result
+ * {@link Table}.
+ *
+ * @param <T> the concrete Transformer type itself (self type), used by {@link
+ *            org.apache.flink.ml.api.misc.param.WithParams}
+ */
+@PublicEvolving
+public interface Transformer<T extends Transformer<T>> extends PipelineStage<T> {
+
+	/**
+	 * Runs this transformer over {@code input} and returns the resulting table.
+	 *
+	 * @param tEnv  the {@link TableEnvironment} the input table is bound to.
+	 * @param input the table to transform.
+	 * @return the table produced by the transformation.
+	 */
+	Table transform(TableEnvironment tEnv, Table input);
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
new file mode 100644
index 0000000..994576f
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.misc.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * Definition of a parameter, including name, type, default value, validator and so on.
+ *
+ * <p>This class is provided to unify the interaction with parameters. All fields are final, and
+ * the alias array is copied both on construction and in {@link #getAlias()}. Later modification of
+ * an array passed in or handed out therefore does not affect this ParamInfo.
+ *
+ * @param <V> the type of the param value
+ */
+@PublicEvolving
+public class ParamInfo<V> {
+	private final String name;
+	private final String[] alias;
+	private final String description;
+	private final boolean isOptional;
+	private final boolean hasDefaultValue;
+	private final V defaultValue;
+	private final ParamValidator<V> validator;
+	private final Class<V> valueClass;
+
+	ParamInfo(String name, String[] alias, String description, boolean isOptional,
+			boolean hasDefaultValue, V defaultValue,
+			ParamValidator<V> validator, Class<V> valueClass) {
+		this.name = name;
+		// Defensive copy: the caller keeps ownership of its array. A null alias is kept as null.
+		this.alias = alias == null ? null : alias.clone();
+		this.description = description;
+		this.isOptional = isOptional;
+		this.hasDefaultValue = hasDefaultValue;
+		this.defaultValue = defaultValue;
+		this.validator = validator;
+		this.valueClass = valueClass;
+	}
+
+	/**
+	 * Returns the name of the parameter. The name must be unique in the stage the ParamInfo
+	 * belongs to.
+	 *
+	 * @return the name of the parameter
+	 */
+	public String getName() {
+		return name;
+	}
+
+	/**
+	 * Returns the aliases of the parameter. The alias will be an empty string array by default.
+	 * A copy is returned, so modifying the result does not change this ParamInfo.
+	 *
+	 * @return a copy of the aliases of the parameter
+	 */
+	public String[] getAlias() {
+		return alias == null ? null : alias.clone();
+	}
+
+	/**
+	 * Returns the description of the parameter.
+	 *
+	 * @return the description of the parameter
+	 */
+	public String getDescription() {
+		return description;
+	}
+
+	/**
+	 * Returns whether the parameter is optional.
+	 *
+	 * @return {@code true} if the param is optional, {@code false} otherwise
+	 */
+	public boolean isOptional() {
+		return isOptional;
+	}
+
+	/**
+	 * Returns whether the parameter has a default value. Since {@code null} may also be a valid
+	 * default value of a parameter, the return of getDefaultValue may be {@code null} even when
+	 * this method returns true.
+	 *
+	 * @return {@code true} if the param has a default value (even if it's a {@code null}),
+	 * {@code false} otherwise
+	 */
+	public boolean hasDefaultValue() {
+		return hasDefaultValue;
+	}
+
+	/**
+	 * Returns the default value of the parameter. The default value should be defined whenever
+	 * possible. The default value can be a {@code null} even if hasDefaultValue returns true.
+	 *
+	 * @return the default value of the param, {@code null} if not defined
+	 */
+	public V getDefaultValue() {
+		return defaultValue;
+	}
+
+	/**
+	 * Returns the validator to validate the value of the parameter.
+	 *
+	 * @return the validator to validate the value of the parameter.
+	 */
+	public ParamValidator<V> getValidator() {
+		return validator;
+	}
+
+	/**
+	 * Returns the class of the param value. It's usually needed in serialization.
+	 *
+	 * @return the class of the param value
+	 */
+	public Class<V> getValueClass() {
+		return valueClass;
+	}
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfoFactory.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfoFactory.java
new file mode 100644
index 0000000..4c49580
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfoFactory.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.misc.param;
+
+/**
+ * Factory to create {@link ParamInfo}s; all ParamInfos should be created via this class.
+ */
+public class ParamInfoFactory {
+	/**
+	 * Returns a ParamInfoBuilder to configure and build a new ParamInfo.
+	 *
+	 * @param name name of the new ParamInfo
+	 * @param valueClass value class of the new ParamInfo
+	 * @param <V> value type of the new ParamInfo
+	 * @return a ParamInfoBuilder
+	 */
+	public static <V> ParamInfoBuilder<V> createParamInfo(String name, Class<V> valueClass) {
+		return new ParamInfoBuilder<>(name, valueClass);
+	}
+
+	/**
+	 * Builder to build a new ParamInfo. Builder is created by ParamInfoFactory with name and
+	 * valueClass set.
+	 *
+	 * @param <V> value type of the new ParamInfo
+	 */
+	public static class ParamInfoBuilder<V> {
+		// name and valueClass are fixed at creation; everything else is configurable.
+		private final String name;
+		private String[] alias = new String[0];
+		private String description;
+		private boolean isOptional = true;
+		private boolean hasDefaultValue = false;
+		private V defaultValue;
+		private ParamValidator<V> validator;
+		private final Class<V> valueClass;
+
+		ParamInfoBuilder(String name, Class<V> valueClass) {
+			this.name = name;
+			this.valueClass = valueClass;
+		}
+
+		/**
+		 * Sets the aliases of the parameter. The aliases may be given as an array or as varargs.
+		 * The array is copied, so modifying it afterwards does not affect the builder.
+		 *
+		 * @param alias the aliases of the parameter
+		 * @return the builder itself
+		 */
+		public ParamInfoBuilder<V> setAlias(String... alias) {
+			// A null array is kept as null, as before; otherwise take a private copy.
+			this.alias = alias == null ? null : alias.clone();
+			return this;
+		}
+
+		/**
+		 * Sets the description of the parameter.
+		 *
+		 * @param description the description of the parameter
+		 * @return the builder itself
+		 */
+		public ParamInfoBuilder<V> setDescription(String description) {
+			this.description = description;
+			return this;
+		}
+
+		/**
+		 * Sets the flag indicating the parameter is optional. The parameter is optional by default.
+		 *
+		 * @return the builder itself
+		 */
+		public ParamInfoBuilder<V> setOptional() {
+			this.isOptional = true;
+			return this;
+		}
+
+		/**
+		 * Sets the flag indicating the parameter is required.
+		 *
+		 * @return the builder itself
+		 */
+		public ParamInfoBuilder<V> setRequired() {
+			this.isOptional = false;
+			return this;
+		}
+
+		/**
+		 * Sets the flag indicating the parameter has default value, and sets the default value.
+		 *
+		 * @param defaultValue the default value; {@code null} is allowed and still counts as a
+		 *                     default value
+		 * @return the builder itself
+		 */
+		public ParamInfoBuilder<V> setHasDefaultValue(V defaultValue) {
+			this.hasDefaultValue = true;
+			this.defaultValue = defaultValue;
+			return this;
+		}
+
+		/**
+		 * Sets the validator to validate the parameter value set by users.
+		 *
+		 * @param validator the validator for the parameter value
+		 * @return the builder itself
+		 */
+		public ParamInfoBuilder<V> setValidator(ParamValidator<V> validator) {
+			this.validator = validator;
+			return this;
+		}
+
+		/**
+		 * Builds the defined ParamInfo and returns it. The ParamInfo will be immutable.
+		 *
+		 * @return the defined ParamInfo
+		 */
+		public ParamInfo<V> build() {
+			return new ParamInfo<>(name, alias, description, isOptional, hasDefaultValue,
+				defaultValue, validator, valueClass);
+		}
+	}
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamValidator.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamValidator.java
new file mode 100644
index 0000000..c95b146
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamValidator.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.misc.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import java.io.Serializable;
+
+/**
+ * An interface used by {@link ParamInfo} to do validation when a parameter value is set.
+ *
+ * <p>It has a single abstract method, so validators can be written as lambdas, e.g.
+ * {@code ParamValidator<Integer> positive = v -> v > 0;}.
+ *
+ * @param <V> the type of the value to validate
+ */
+@PublicEvolving
+@FunctionalInterface
+public interface ParamValidator<V> extends Serializable {
+	/**
+	 * Validates a parameter value.
+	 *
+	 * @param value value to validate
+	 * @return {@code true} if the value is valid, {@code false} otherwise
+	 */
+	boolean validate(V value);
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
new file mode 100644
index 0000000..0c1e0d8
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.misc.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A map-like container for parameter values, keyed by the parameter name given by
+ * {@link ParamInfo#getName()}. This class is provided to unify the interaction with parameters.
+ */
+@PublicEvolving
+public class Params implements Serializable {
+	/**
+	 * Mapper shared by {@link #toJson()} and {@link #loadJson(String, Map)}. An ObjectMapper is
+	 * relatively costly to create and is thread-safe as long as it is not reconfigured (this one
+	 * never is), so it is created once instead of on every call. Being static, it is not part of
+	 * the serialized state of a Params.
+	 */
+	private static final ObjectMapper MAPPER = new ObjectMapper();
+
+	/** Parameter values keyed by parameter name. */
+	private final Map<String, Object> paramMap = new HashMap<>();
+
+	/**
+	 * Returns the value of the specific parameter, or the default value defined in the
+	 * {@code info} if this Params doesn't contain the parameter.
+	 *
+	 * @param info the info of the specific parameter, usually with default value
+	 * @param <V> the type of the specific parameter
+	 * @return the value of the specific parameter, or the default value defined in the
+	 *         {@code info} if this Params doesn't contain the parameter
+	 * @throws RuntimeException if this Params doesn't contain the specific parameter while the
+	 *                          parameter is not optional and has no default value in the
+	 *                          {@code info}
+	 */
+	@SuppressWarnings("unchecked")
+	public <V> V get(ParamInfo<V> info) {
+		// Values are only stored through set(info, value), so the stored type matches V.
+		V value = (V) paramMap.getOrDefault(info.getName(), info.getDefaultValue());
+		if (value == null && !info.isOptional() && !info.hasDefaultValue()) {
+			throw new RuntimeException(info.getName() +
+				" not exist which is not optional and don't have a default value");
+		}
+		return value;
+	}
+
+	/**
+	 * Sets the value of the specific parameter. Setting an optional parameter to {@code null}
+	 * removes it from this Params.
+	 *
+	 * @param info the info of the specific parameter to set
+	 * @param value the value to be set to the specific parameter
+	 * @param <V> the type of the specific parameter
+	 * @return this Params, to allow chained calls
+	 * @throws RuntimeException if {@code value} is {@code null} while the parameter is not
+	 *                          optional, or if the {@code info} has a validator and the
+	 *                          {@code value} is evaluated as illegal by the validator
+	 */
+	public <V> Params set(ParamInfo<V> info, V value) {
+		if (!info.isOptional() && value == null) {
+			throw new RuntimeException(
+				"Setting " + info.getName() + " as null while it's not a optional param");
+		}
+		if (value == null) {
+			remove(info);
+			return this;
+		}
+
+		if (info.getValidator() != null && !info.getValidator().validate(value)) {
+			throw new RuntimeException(
+				"Setting " + info.getName() + " as a invalid value:" + value);
+		}
+		paramMap.put(info.getName(), value);
+		return this;
+	}
+
+	/**
+	 * Removes the specific parameter from this Params. Does nothing if it is not present.
+	 *
+	 * @param info the info of the specific parameter to remove
+	 * @param <V> the type of the specific parameter
+	 */
+	public <V> void remove(ParamInfo<V> info) {
+		paramMap.remove(info.getName());
+	}
+
+	/**
+	 * Creates and returns a copy of this Params. The copy has its own parameter map, so setting or
+	 * removing parameters on one does not affect the other. The parameter values themselves are
+	 * shared, not copied.
+	 *
+	 * @return a copy of this Params
+	 */
+	@Override
+	public Params clone() {
+		Params newParams = new Params();
+		newParams.paramMap.putAll(this.paramMap);
+		return newParams;
+	}
+
+	/**
+	 * Returns a json containing all parameters in this Params. The result is a json object that
+	 * maps each parameter name to the json-serialized form of its value, stored as a string.
+	 *
+	 * @return a json containing all parameters in this Params
+	 * @throws RuntimeException if a parameter value cannot be serialized to json
+	 */
+	public String toJson() {
+		Map<String, String> stringMap = new HashMap<>();
+		try {
+			for (Map.Entry<String, Object> e : paramMap.entrySet()) {
+				stringMap.put(e.getKey(), MAPPER.writeValueAsString(e.getValue()));
+			}
+			return MAPPER.writeValueAsString(stringMap);
+		} catch (JsonProcessingException e) {
+			throw new RuntimeException("Failed to serialize params to json", e);
+		}
+	}
+
+	/**
+	 * Restores the parameters from the given json, as produced by {@link #toJson()}. After the
+	 * restoration the restored parameters equal those that were serialized. The class mapping of
+	 * the parameters in the json is required because it is hard to directly restore a param of a
+	 * user-defined type. Params will be treated as String if they don't exist in the
+	 * {@code classMap}.
+	 *
+	 * <p>Restored entries are added to this Params and override existing values with the same
+	 * name. Parameters already present but absent from the json are left untouched.
+	 *
+	 * @param json the json String to restore from
+	 * @param classMap the classes of the parameters contained in the json
+	 * @throws RuntimeException if the json cannot be parsed or a value cannot be converted to its
+	 *                          class
+	 */
+	public void loadJson(String json, Map<String, Class<?>> classMap) {
+		try {
+			// Read the outer object explicitly as a String-valued map. A raw Map would let a
+			// non-string value slip through and fail later with a ClassCastException outside the
+			// IOException handler below; with the typed read, Jackson converts scalar values to
+			// text or reports an IOException.
+			Map<String, String> m = MAPPER.readValue(json,
+				MAPPER.getTypeFactory().constructMapType(HashMap.class, String.class, String.class));
+			for (Map.Entry<String, String> e : m.entrySet()) {
+				Class<?> valueClass = classMap.getOrDefault(e.getKey(), String.class);
+				paramMap.put(e.getKey(), MAPPER.readValue(e.getValue(), valueClass));
+			}
+		} catch (IOException e) {
+			throw new RuntimeException("Failed to deserialize json:" + json, e);
+		}
+	}
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/WithParams.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/WithParams.java
new file mode 100644
index 0000000..5e87508
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/WithParams.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.misc.param;
+
+/**
+ * Common interface for classes that carry {@link Params}. Parameters are widely used in machine
+ * learning, and this interface gives such classes a uniform way to read and write them.
+ *
+ * @param <T> the concrete implementing type; {@link #set} returns it so that calls can be chained
+ */
+public interface WithParams<T> {
+	/**
+	 * Returns the container holding all parameters of this object.
+	 *
+	 * @return all the parameters
+	 */
+	Params getParams();
+
+	/**
+	 * Sets the value of a specific parameter in the container returned by {@link #getParams()}.
+	 *
+	 * @param info the info of the specific param to set
+	 * @param value the value to be set to the specific param
+	 * @param <V> the type of the specific param
+	 * @return this object, typed as {@code T}
+	 */
+	@SuppressWarnings("unchecked")
+	default <V> T set(ParamInfo<V> info, V value) {
+		final Params params = getParams();
+		params.set(info, value);
+		final T self = (T) this;
+		return self;
+	}
+
+	/**
+	 * Returns the value of a specific param, read from the container returned by
+	 * {@link #getParams()}.
+	 *
+	 * @param info the info of the specific param, usually with default value
+	 * @param <V> the type of the specific param
+	 * @return the value of the specific param, or the default value defined in the {@code info}
+	 *         if the inner Params doesn't contain this param
+	 */
+	default <V> V get(ParamInfo<V> info) {
+		final Params params = getParams();
+		return params.get(info);
+	}
+}
diff --git a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/util/param/ExtractParamInfosUtil.java b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/util/param/ExtractParamInfosUtil.java
new file mode 100644
index 0000000..c59d2e0
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/util/param/ExtractParamInfosUtil.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util.param;
+
+import org.apache.flink.ml.api.misc.param.ParamInfo;
+import org.apache.flink.ml.api.misc.param.WithParams;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Utility to extract all ParamInfos defined in a WithParams, mainly used in persistence.
+ */
+public final class ExtractParamInfosUtil {
+ private static final Logger LOG = LoggerFactory.getLogger(ExtractParamInfosUtil.class);
+
+ /**
+ * Extracts all ParamInfos defined in the given WithParams, including those in its superclasses
+ * and interfaces.
+ *
+ * @param s the WithParams to extract ParamInfos from
+ * @return the list of all ParamInfos defined in s
+ */
+ public static List<ParamInfo> extractParamInfos(WithParams s) {
+ return extractParamInfos(s, s.getClass());
+ }
+
+ private static List<ParamInfo> extractParamInfos(WithParams s, Class clz) {
+ List<ParamInfo> result = new ArrayList<>();
+ if (clz == null) {
+ return result;
+ }
+
+ Field[] fields = clz.getDeclaredFields();
+ for (Field f : fields) {
+ f.setAccessible(true);
+ if (ParamInfo.class.isAssignableFrom(f.getType())) {
+ try {
+ result.add((ParamInfo) f.get(s));
+ } catch (IllegalAccessException e) {
+ LOG.warn("Failed to extract param info {}, ignore it", f.getName(), e);
+ }
+ }
+ }
+
+ result.addAll(extractParamInfos(s, clz.getSuperclass()));
+ for (Class c : clz.getInterfaces()) {
+ result.addAll(extractParamInfos(s, c));
+ }
+
+ return result;
+ }
+}
diff --git a/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
new file mode 100644
index 0000000..fc82634
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.ml.api.misc.param.ParamInfo;
+import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
+import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+/**
+ * Tests the behavior of {@link Pipeline}: fitting a pipeline that contains estimators, rejecting
+ * transform on a pipeline that still contains estimators, and restoring pipelines from json.
+ *
+ * <p>JUnit assertions are used rather than the Java {@code assert} keyword, so the checks are not
+ * silently skipped when the JVM runs without {@code -ea}.
+ */
+public class PipelineTest {
+	@Rule
+	public ExpectedException thrown = ExpectedException.none();
+
+	@Test
+	public void testPipelineBehavior() {
+		Pipeline pipeline = new Pipeline();
+		pipeline.appendStage(new MockTransformer("a"));
+		pipeline.appendStage(new MockEstimator("b"));
+		pipeline.appendStage(new MockEstimator("c"));
+		pipeline.appendStage(new MockTransformer("d"));
+		org.junit.Assert.assertEquals("a_b_c_d", describePipeline(pipeline));
+
+		// Each MockEstimator fits to a MockModel described as "m" + its own description.
+		Pipeline pipelineModel = pipeline.fit(null, null);
+		org.junit.Assert.assertEquals("a_mb_mc_d", describePipeline(pipelineModel));
+
+		thrown.expect(RuntimeException.class);
+		thrown.expectMessage("Pipeline contains Estimator, need to fit first.");
+		pipeline.transform(null, null);
+	}
+
+	@Test
+	public void testPipelineRestore() {
+		Pipeline pipeline = new Pipeline();
+		pipeline.appendStage(new MockTransformer("a"));
+		pipeline.appendStage(new MockEstimator("b"));
+		pipeline.appendStage(new MockEstimator("c"));
+		pipeline.appendStage(new MockTransformer("d"));
+		String pipelineJson = pipeline.toJson();
+
+		Pipeline restoredPipeline = new Pipeline(pipelineJson);
+		org.junit.Assert.assertEquals("a_b_c_d", describePipeline(restoredPipeline));
+
+		Pipeline pipelineModel = pipeline.fit(null, null);
+		String modelJson = pipelineModel.toJson();
+
+		Pipeline restoredPipelineModel = new Pipeline(modelJson);
+		org.junit.Assert.assertEquals("a_mb_mc_d", describePipeline(restoredPipelineModel));
+	}
+
+	/**
+	 * Joins the descriptions of all stages of the pipeline with "_". Every stage must implement
+	 * {@link SelfDescribe}.
+	 */
+	private static String describePipeline(Pipeline p) {
+		StringBuilder res = new StringBuilder();
+		for (PipelineStage s : p.getStages()) {
+			if (res.length() != 0) {
+				res.append("_");
+			}
+			res.append(((SelfDescribe) s).describe());
+		}
+		return res.toString();
+	}
+
+	/**
+	 * Interface to describe a class with a string, only for pipeline test.
+	 */
+	private interface SelfDescribe {
+		ParamInfo<String> DESCRIPTION = ParamInfoFactory.createParamInfo("description",
+			String.class).build();
+
+		String describe();
+	}
+
+	/**
+	 * Mock estimator for pipeline test.
+	 */
+	public static class MockEstimator implements Estimator<MockEstimator, MockModel>, SelfDescribe {
+		private final Params params = new Params();
+
+		public MockEstimator() {
+		}
+
+		MockEstimator(String description) {
+			set(DESCRIPTION, description);
+		}
+
+		@Override
+		public MockModel fit(TableEnvironment tEnv, Table input) {
+			return new MockModel("m" + describe());
+		}
+
+		@Override
+		public Params getParams() {
+			return params;
+		}
+
+		@Override
+		public String describe() {
+			return get(DESCRIPTION);
+		}
+	}
+
+	/**
+	 * Mock transformer for pipeline test.
+	 */
+	public static class MockTransformer implements Transformer<MockTransformer>, SelfDescribe {
+		private final Params params = new Params();
+
+		public MockTransformer() {
+		}
+
+		MockTransformer(String description) {
+			set(DESCRIPTION, description);
+		}
+
+		@Override
+		public Table transform(TableEnvironment tEnv, Table input) {
+			return input;
+		}
+
+		@Override
+		public Params getParams() {
+			return params;
+		}
+
+		@Override
+		public String describe() {
+			return get(DESCRIPTION);
+		}
+	}
+
+	/**
+	 * Mock model for pipeline test.
+	 */
+	public static class MockModel implements Model<MockModel>, SelfDescribe {
+		private final Params params = new Params();
+
+		public MockModel() {
+		}
+
+		MockModel(String description) {
+			set(DESCRIPTION, description);
+		}
+
+		@Override
+		public Table transform(TableEnvironment tEnv, Table input) {
+			return input;
+		}
+
+		@Override
+		public Params getParams() {
+			return params;
+		}
+
+		@Override
+		public String describe() {
+			return get(DESCRIPTION);
+		}
+	}
+}
diff --git a/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
new file mode 100644
index 0000000..8bdf95b
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.misc;
+
+import org.apache.flink.ml.api.misc.param.ParamInfo;
+import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
+import org.apache.flink.ml.api.misc.param.Params;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+/**
+ * Test for the behavior and validator of {@link Params}.
+ *
+ * <p>JUnit assertions are used rather than the Java {@code assert} keyword, so the checks are not
+ * silently skipped when the JVM runs without {@code -ea}.
+ */
+public class ParamsTest {
+	@Rule
+	public ExpectedException thrown = ExpectedException.none();
+
+	@Test
+	public void testDefaultBehavior() {
+		Params params = new Params();
+
+		ParamInfo<String> optionalWithoutDefault =
+			ParamInfoFactory.createParamInfo("a", String.class).build();
+		org.junit.Assert.assertNull(params.get(optionalWithoutDefault));
+
+		ParamInfo<String> optionalWithDefault =
+			ParamInfoFactory.createParamInfo("a", String.class).setHasDefaultValue("def").build();
+		org.junit.Assert.assertEquals("def", params.get(optionalWithDefault));
+
+		ParamInfo<String> requiredWithDefault =
+			ParamInfoFactory.createParamInfo("a", String.class).setRequired()
+				.setHasDefaultValue("def").build();
+		org.junit.Assert.assertEquals("def", params.get(requiredWithDefault));
+
+		ParamInfo<String> requiredWithoutDefault =
+			ParamInfoFactory.createParamInfo("a", String.class).setRequired().build();
+		thrown.expect(RuntimeException.class);
+		thrown.expectMessage("a not exist which is not optional and don't have a default value");
+		params.get(requiredWithoutDefault);
+	}
+
+	@Test
+	public void testValidator() {
+		Params params = new Params();
+
+		ParamInfo<Integer> intParam =
+			ParamInfoFactory.createParamInfo("a", Integer.class).setValidator(i -> i > 0).build();
+		params.set(intParam, 1);
+		// A value accepted by the validator is stored and can be read back.
+		org.junit.Assert.assertEquals(Integer.valueOf(1), params.get(intParam));
+
+		thrown.expect(RuntimeException.class);
+		thrown.expectMessage("Setting a as a invalid value:0");
+		params.set(intParam, 0);
+	}
+}
diff --git a/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/util/param/ExtractParamInfosUtilTest.java b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/util/param/ExtractParamInfosUtilTest.java
new file mode 100644
index 0000000..5f467e9
--- /dev/null
+++ b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/util/param/ExtractParamInfosUtilTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util.param;
+
+import org.apache.flink.ml.api.misc.param.ParamInfo;
+import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
+import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.api.misc.param.WithParams;
+
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Test for {@link ExtractParamInfosUtil}.
+ *
+ * <p>JUnit assertions are used rather than the Java {@code assert} keyword, so the checks are not
+ * silently skipped when the JVM runs without {@code -ea}.
+ */
+public class ExtractParamInfosUtilTest {
+
+	@Test
+	public void testExtractParamInfos() {
+		List<ParamInfo> noParamInfos =
+			ExtractParamInfosUtil.extractParamInfos(new WithNoParamInfo());
+		org.junit.Assert.assertTrue(noParamInfos.isEmpty());
+
+		List<ParamInfo> classParamInfos =
+			ExtractParamInfosUtil.extractParamInfos(new WithTestParamInfo());
+		org.junit.Assert.assertEquals(1, classParamInfos.size());
+		org.junit.Assert.assertEquals("KSC", classParamInfos.get(0).getName());
+
+		// Expects the private field of the class plus the ones inherited from the superclass and
+		// the interface; names are sorted because extraction order is not part of the contract.
+		List<ParamInfo> allParamInfos =
+			ExtractParamInfosUtil.extractParamInfos(new TestParamInfoWithInheritedParamInfos());
+		String[] sortedCorrectParamNames = new String[]{"KCP", "KI", "KSC"};
+		org.junit.Assert.assertEquals(3, allParamInfos.size());
+		org.junit.Assert.assertArrayEquals(sortedCorrectParamNames,
+			allParamInfos.stream().map(ParamInfo::getName).sorted().toArray(String[]::new));
+	}
+
+	/**
+	 * Mock WithParams implementation with no ParamInfo. Only for test.
+	 */
+	public static class WithNoParamInfo implements WithParams<WithNoParamInfo> {
+
+		@Override
+		public Params getParams() {
+			return null;
+		}
+	}
+
+	/**
+	 * Mock WithParams implementation with one ParamInfo. Only for test.
+	 * @param <T> subclass of WithTestParamInfo
+	 */
+	public static class WithTestParamInfo<T extends WithTestParamInfo> implements WithParams<T> {
+		public static final ParamInfo<String> KSC = ParamInfoFactory
+			.createParamInfo("KSC", String.class)
+			.setDescription("key from super class").build();
+
+		@Override
+		public Params getParams() {
+			return null;
+		}
+	}
+
+	/**
+	 * Mock interface extending WithParams with one ParamInfo. Only for test.
+	 * @param <T> implementation class of InterfaceWithParamInfo
+	 */
+	public interface InterfaceWithParamInfo<T extends InterfaceWithParamInfo>
+		extends WithParams<T> {
+		ParamInfo<String> KI = ParamInfoFactory.createParamInfo("KI", String.class)
+			.setDescription("key from interface").build();
+	}
+
+	/**
+	 * Mock WithParams inheriting ParamInfos from superclass and interface. Only for test.
+	 */
+	public static class TestParamInfoWithInheritedParamInfos
+		extends WithTestParamInfo<TestParamInfoWithInheritedParamInfos>
+		implements InterfaceWithParamInfo<TestParamInfoWithInheritedParamInfos> {
+		private static final ParamInfo<String> KCP = ParamInfoFactory
+			.createParamInfo("KCP", String.class)
+			.setDescription("key in the class which is private").build();
+
+		@Override
+		public Params getParams() {
+			return null;
+		}
+	}
+}
diff --git a/flink-ml-parent/pom.xml b/flink-ml-parent/pom.xml
new file mode 100644
index 0000000..6ca2ebc
--- /dev/null
+++ b/flink-ml-parent/pom.xml
@@ -0,0 +1,39 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <!-- Inherits build configuration from the top-level Flink parent POM one directory up. -->
+ <parent>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-parent</artifactId>
+ <version>1.9-SNAPSHOT</version>
+ <relativePath>..</relativePath>
+ </parent>
+
+ <!-- Aggregator module for the Flink ML sub-modules (packaging is "pom"). -->
+ <artifactId>flink-ml-parent</artifactId>
+
+ <packaging>pom</packaging>
+
+ <!-- Sub-modules built as part of flink-ml-parent. -->
+ <modules>
+ <module>flink-ml-api</module>
+ </modules>
+</project>
diff --git a/pom.xml b/pom.xml
index 086e452..38b3601 100644
--- a/pom.xml
+++ b/pom.xml
@@ -86,6 +86,7 @@ under the License.
<module>flink-fs-tests</module>
<module>flink-docs</module>
<module>flink-python</module>
+ <module>flink-ml-parent</module>
</modules>
<properties>