You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by jq...@apache.org on 2021/11/09 12:12:18 UTC
[flink-ml] 02/02: [FLINK-24354][FLIP-174] Improve the WithParams interface
This is an automated email from the ASF dual-hosted git repository.
jqin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 9c44eef25970338fe32dcf77ce45efac74c4324f
Author: Dong Lin <li...@gmail.com>
AuthorDate: Sun Sep 26 21:40:59 2021 +0800
[FLINK-24354][FLIP-174] Improve the WithParams interface
---
flink-ml-api/pom.xml | 15 +
.../org/apache/flink/ml/api/core/Pipeline.java | 23 +-
.../apache/flink/ml/api/core/PipelineModel.java | 23 +-
.../java/org/apache/flink/ml/api/core/Stage.java | 2 +-
.../org/apache/flink/ml/param/BooleanParam.java | 35 ++
.../apache/flink/ml/param/DoubleArrayParam.java | 35 ++
.../org/apache/flink/ml/param/DoubleParam.java | 35 ++
.../org/apache/flink/ml/param/FloatArrayParam.java | 35 ++
.../java/org/apache/flink/ml/param/FloatParam.java | 32 ++
.../org/apache/flink/ml/param/IntArrayParam.java | 35 ++
.../java/org/apache/flink/ml/param/IntParam.java | 35 ++
.../org/apache/flink/ml/param/LongArrayParam.java | 35 ++
.../java/org/apache/flink/ml/param/LongParam.java | 32 ++
.../main/java/org/apache/flink/ml/param/Param.java | 98 ++++++
.../org/apache/flink/ml/param/ParamValidator.java | 40 +++
.../org/apache/flink/ml/param/ParamValidators.java | 98 ++++++
.../apache/flink/ml/param/StringArrayParam.java | 35 ++
.../org/apache/flink/ml/param/StringParam.java | 35 ++
.../java/org/apache/flink/ml/param/WithParams.java | 135 ++++++++
.../java/org/apache/flink/ml/util/ParamUtils.java | 89 +++++
.../org/apache/flink/ml/util/ReadWriteUtils.java | 279 +++++++++++++++
.../apache/flink/ml/api/core/ExampleStages.java | 244 ++++++++++++++
.../org/apache/flink/ml/api/core/PipelineTest.java | 202 +++++------
.../org/apache/flink/ml/api/core/StageTest.java | 375 +++++++++++++++++++++
pom.xml | 2 -
25 files changed, 1863 insertions(+), 141 deletions(-)
diff --git a/flink-ml-api/pom.xml b/flink-ml-api/pom.xml
index 81fdcc7..ddfc659 100644
--- a/flink-ml-api/pom.xml
+++ b/flink-ml-api/pom.xml
@@ -38,6 +38,21 @@ under the License.
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-table-planner_${scala.binary.version}</artifactId>
+ <version>${flink.version}</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-test-utils_${scala.binary.version}</artifactId>
+ <version>${flink.version}</version>
+ <scope>test</scope>
+ </dependency>
+
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-shaded-jackson</artifactId>
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
index a5fed01..f1e5d0c 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
@@ -20,13 +20,17 @@ package org.apache.flink.ml.api.core;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
/**
* A Pipeline acts as an Estimator. It consists of an ordered list of stages, each of which could be
@@ -36,10 +40,11 @@ import java.util.List;
public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
private static final long serialVersionUID = 6384850154817512318L;
private final List<Stage<?>> stages;
- private final Params params = new Params();
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
public Pipeline(List<Stage<?>> stages) {
this.stages = stages;
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
/**
@@ -97,17 +102,17 @@ public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
}
@Override
- public void save(String path) throws IOException {
- throw new UnsupportedOperationException();
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
}
- public static Pipeline load(String path) throws IOException {
- throw new UnsupportedOperationException();
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.savePipeline(this, stages, path);
}
- @Override
- public Params getParams() {
- return params;
+ public static Pipeline load(String path) throws IOException {
+ return new Pipeline(ReadWriteUtils.loadPipeline(path, Pipeline.class.getName()));
}
/**
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
index 704fa8e..45bb757 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
@@ -20,12 +20,16 @@ package org.apache.flink.ml.api.core;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import java.io.IOException;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
/**
* A PipelineModel acts as a Model. It consists of an ordered list of stages, each of which could be
@@ -35,10 +39,11 @@ import java.util.List;
public final class PipelineModel implements Model<PipelineModel> {
private static final long serialVersionUID = 6184950154217411318L;
private final List<Stage<?>> stages;
- private final Params params = new Params();
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
public PipelineModel(List<Stage<?>> stages) {
this.stages = stages;
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
/**
@@ -58,17 +63,17 @@ public final class PipelineModel implements Model<PipelineModel> {
}
@Override
- public void save(String path) throws IOException {
- throw new UnsupportedOperationException();
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
}
- public static PipelineModel load(String path) throws IOException {
- throw new UnsupportedOperationException();
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.savePipeline(this, stages, path);
}
- @Override
- public Params getParams() {
- return params;
+ public static PipelineModel load(String path) throws IOException {
+ return new PipelineModel(ReadWriteUtils.loadPipeline(path, PipelineModel.class.getName()));
}
/**
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
index 551c5e5..168599b 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
@@ -19,7 +19,7 @@
package org.apache.flink.ml.api.core;
import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.ml.api.misc.param.WithParams;
+import org.apache.flink.ml.param.WithParams;
import java.io.IOException;
import java.io.Serializable;
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java
new file mode 100644
index 0000000..dd96ebe
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the boolean parameter. */
+public class BooleanParam extends Param<Boolean> {
+
+ public BooleanParam(
+ String name,
+ String description,
+ Boolean defaultValue,
+ ParamValidator<Boolean> validator) {
+ super(name, Boolean.class, description, defaultValue, validator);
+ }
+
+ public BooleanParam(String name, String description, Boolean defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java
new file mode 100644
index 0000000..b86dd00
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the double array parameter. */
+public class DoubleArrayParam extends Param<Double[]> {
+
+ public DoubleArrayParam(
+ String name,
+ String description,
+ Double[] defaultValue,
+ ParamValidator<Double[]> validator) {
+ super(name, Double[].class, description, defaultValue, validator);
+ }
+
+ public DoubleArrayParam(String name, String description, Double[] defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java
new file mode 100644
index 0000000..f6d4911
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the double parameter. */
+public class DoubleParam extends Param<Double> {
+
+ public DoubleParam(
+ String name,
+ String description,
+ Double defaultValue,
+ ParamValidator<Double> validator) {
+ super(name, Double.class, description, defaultValue, validator);
+ }
+
+ public DoubleParam(String name, String description, Double defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java
new file mode 100644
index 0000000..4224557
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the float array parameter. */
+public class FloatArrayParam extends Param<Float[]> {
+
+ public FloatArrayParam(
+ String name,
+ String description,
+ Float[] defaultValue,
+ ParamValidator<Float[]> validator) {
+ super(name, Float[].class, description, defaultValue, validator);
+ }
+
+ public FloatArrayParam(String name, String description, Float[] defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java
new file mode 100644
index 0000000..0de890c
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the float parameter. */
+public class FloatParam extends Param<Float> {
+
+ public FloatParam(
+ String name, String description, Float defaultValue, ParamValidator<Float> validator) {
+ super(name, Float.class, description, defaultValue, validator);
+ }
+
+ public FloatParam(String name, String description, Float defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java
new file mode 100644
index 0000000..4f7c630
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the integer array parameter. */
+public class IntArrayParam extends Param<Integer[]> {
+
+ public IntArrayParam(
+ String name,
+ String description,
+ Integer[] defaultValue,
+ ParamValidator<Integer[]> validator) {
+ super(name, Integer[].class, description, defaultValue, validator);
+ }
+
+ public IntArrayParam(String name, String description, Integer[] defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java
new file mode 100644
index 0000000..4178e22
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the integer parameter. */
+public class IntParam extends Param<Integer> {
+
+ public IntParam(
+ String name,
+ String description,
+ Integer defaultValue,
+ ParamValidator<Integer> validator) {
+ super(name, Integer.class, description, defaultValue, validator);
+ }
+
+ public IntParam(String name, String description, Integer defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java
new file mode 100644
index 0000000..5e4fc47
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the long array parameter. */
+public class LongArrayParam extends Param<Long[]> {
+
+ public LongArrayParam(
+ String name,
+ String description,
+ Long[] defaultValue,
+ ParamValidator<Long[]> validator) {
+ super(name, Long[].class, description, defaultValue, validator);
+ }
+
+ public LongArrayParam(String name, String description, Long[] defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java
new file mode 100644
index 0000000..3fd7dd8
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the long parameter. */
+public class LongParam extends Param<Long> {
+
+ public LongParam(
+ String name, String description, Long defaultValue, ParamValidator<Long> validator) {
+ super(name, Long.class, description, defaultValue, validator);
+ }
+
+ public LongParam(String name, String description, Long defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
new file mode 100644
index 0000000..b7a1aef
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.util.ReadWriteUtils;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+/**
+ * Definition of a parameter, including name, class, description, default value and the validator.
+ *
+ * @param <T> The class type of the parameter value.
+ */
+@PublicEvolving
+public class Param<T> implements Serializable {
+ private static final long serialVersionUID = 4396556083935765299L;
+
+ public final String name;
+ public final Class<T> clazz;
+ public final String description;
+ public final T defaultValue;
+ public final ParamValidator<T> validator;
+
+ public Param(
+ String name,
+ Class<T> clazz,
+ String description,
+ T defaultValue,
+ ParamValidator<T> validator) {
+ this.name = name;
+ this.clazz = clazz;
+ this.description = description;
+ this.defaultValue = defaultValue;
+ this.validator = validator;
+
+ if (defaultValue != null && !validator.validate(defaultValue)) {
+ throw new IllegalArgumentException(
+ "Parameter " + name + " is given an invalid value " + defaultValue);
+ }
+ }
+
+ /**
+ * Encodes the given object into a json-formatted string.
+ *
+ * @param value An object of class type T.
+ * @return A json-formatted string.
+ */
+ public String jsonEncode(T value) throws IOException {
+ return ReadWriteUtils.OBJECT_MAPPER.writeValueAsString(value);
+ }
+
+ /**
+ * Decodes the given string into an object of class type T.
+ *
+ * @param json A json-formatted string.
+ * @return An object of class type T.
+ */
+ @SuppressWarnings("unchecked")
+ public T jsonDecode(String json) throws IOException {
+ return ReadWriteUtils.OBJECT_MAPPER.readValue(json, clazz);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!(obj instanceof Param)) {
+ return false;
+ }
+ return ((Param<?>) obj).name.equals(name);
+ }
+
+ @Override
+ public int hashCode() {
+ return name.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return name;
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java
new file mode 100644
index 0000000..afdcd9a
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import java.io.Serializable;
+
+/**
+ * An interface to validate that a parameter value is valid.
+ *
+ * @param <T> The class type of the parameter value.
+ */
+@PublicEvolving
+public interface ParamValidator<T> extends Serializable {
+
+ /**
+ * Validate whether the parameter value is valid.
+ *
+ * @param value The parameter value.
+ * @return A boolean indicating whether the parameter value is valid.
+ */
+ boolean validate(T value);
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java
new file mode 100644
index 0000000..925ccb2
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+/** Factory methods for common validation functions on numerical values. */
+public class ParamValidators {
+
+ // Always return true.
+ public static <T> ParamValidator<T> alwaysTrue() {
+ return (value) -> true;
+ }
+
+ // Check if the parameter value is greater than lowerBound.
+ public static <T> ParamValidator<T> gt(double lowerBound) {
+ return (value) -> value != null && ((Number) value).doubleValue() > lowerBound;
+ }
+
+ // Check if the parameter value is greater than or equal to lowerBound.
+ public static <T> ParamValidator<T> gtEq(double lowerBound) {
+ return (value) -> value != null && ((Number) value).doubleValue() >= lowerBound;
+ }
+
+ // Check if the parameter value is less than upperBound.
+ public static <T> ParamValidator<T> lt(double upperBound) {
+ return (value) -> value != null && ((Number) value).doubleValue() < upperBound;
+ }
+
+ // Check if the parameter value is less than or equal to upperBound.
+ public static <T> ParamValidator<T> ltEq(double upperBound) {
+ return (value) -> value != null && ((Number) value).doubleValue() <= upperBound;
+ }
+
+ /**
+ * Check if the parameter value is in the range from lowerBound to upperBound.
+ *
+ * @param lowerInclusive if true, range includes value = lowerBound
+ * @param upperInclusive if true, range includes value = upperBound
+ */
+ public static <T> ParamValidator<T> inRange(
+ double lowerBound, double upperBound, boolean lowerInclusive, boolean upperInclusive) {
+ return new ParamValidator<T>() {
+ @Override
+ public boolean validate(T obj) {
+ if (obj == null) {
+ return false;
+ }
+ double value = ((Number) obj).doubleValue();
+ return (value >= lowerBound)
+ && (value <= upperBound)
+ && (lowerInclusive || value != lowerBound)
+ && (upperInclusive || value != upperBound);
+ }
+ };
+ }
+
+ // Check if the parameter value is in the range [lowerBound, upperBound].
+ public static <T> ParamValidator<T> inRange(double lowerBound, double upperBound) {
+ return inRange(lowerBound, upperBound, true, true);
+ }
+
+ // Check if the parameter value is in the array of allowed values.
+ public static <T> ParamValidator<T> inArray(T... allowed) {
+ return new ParamValidator<T>() {
+ @Override
+ public boolean validate(T value) {
+ return value != null && ArrayUtils.contains(allowed, value);
+ }
+ };
+ }
+
+ // Check if the parameter value is not null.
+ public static <T> ParamValidator<T> notNull() {
+ return new ParamValidator<T>() {
+ @Override
+ public boolean validate(T value) {
+ return value != null;
+ }
+ };
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java
new file mode 100644
index 0000000..5062463
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** Class for the string array parameter. */
+public class StringArrayParam extends Param<String[]> {
+
+ public StringArrayParam(
+ String name,
+ String description,
+ String[] defaultValue,
+ ParamValidator<String[]> validator) {
+ super(name, String[].class, description, defaultValue, validator);
+ }
+
+ public StringArrayParam(String name, String description, String[] defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java
new file mode 100644
index 0000000..1736354
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+/** A {@link Param} whose value is a string. */
+public class StringParam extends Param<String> {
+
+ /**
+ * Creates a string parameter that uses {@link ParamValidators#alwaysTrue()} as its validator.
+ */
+ public StringParam(String name, String description, String defaultValue) {
+ this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ }
+
+ /** Creates a string parameter with the given validator. */
+ public StringParam(
+ String name,
+ String description,
+ String defaultValue,
+ ParamValidator<String> validator) {
+ super(name, String.class, description, defaultValue, validator);
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java
new file mode 100644
index 0000000..f631c8e
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.param;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.util.ParamUtils;
+
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Interface for classes that take parameters. It provides APIs to set and get parameters.
+ *
+ * @param <T> The class type of WithParams implementation itself.
+ */
+@PublicEvolving
+public interface WithParams<T> {
+
+ /**
+ * Gets the parameter by its name.
+ *
+ * <p>The lookup scans the keys of {@link #getParamMap()}, so only parameters that are present
+ * in that map can be found.
+ *
+ * @param name The parameter name.
+ * @param <V> The class type of the parameter value.
+ * @return The parameter, or null if the parameter map has no parameter with the given name.
+ */
+ @SuppressWarnings("unchecked")
+ default <V> Param<V> getParam(String name) {
+ Optional<Param<?>> result =
+ getParamMap().keySet().stream().filter(param -> param.name.equals(name)).findAny();
+ // The caller chooses V; the cast is unchecked and is not verified against param.clazz.
+ return (Param<V>) result.orElse(null);
+ }
+
+ /**
+ * Stores the given value for the given parameter in the map returned by {@link #getParamMap()}.
+ *
+ * <p>A non-null value is first checked against the parameter's class, then every value
+ * (including null) is checked by the parameter's validator.
+ *
+ * @param param The parameter.
+ * @param value The parameter value.
+ * @return The WithParams instance itself.
+ * @throws ClassCastException if the value is not null and not an instance of the parameter's
+ * class.
+ * @throws IllegalArgumentException if the parameter's validator rejects the value.
+ */
+ @SuppressWarnings("unchecked")
+ default <V> T set(Param<V> param, V value) {
+ boolean hasValue = value != null;
+ if (hasValue && !param.clazz.isAssignableFrom(value.getClass())) {
+ String typeError =
+ "Parameter "
+ + param.name
+ + " is given a value with incompatible class "
+ + value.getClass().getName();
+ throw new ClassCastException(typeError);
+ }
+
+ if (param.validator.validate(value)) {
+ getParamMap().put(param, value);
+ return (T) this;
+ }
+
+ // The validator rejected the value; a null value gets its own message.
+ String reason =
+ hasValue
+ ? "Parameter " + param.name + " is given an invalid value " + value.toString()
+ : "Parameter " + param.name + "'s value should not be null";
+ throw new IllegalArgumentException(reason);
+ }
+
+ /**
+ * Looks up the value of the given parameter in the map returned by {@link #getParamMap()}.
+ *
+ * @param param The parameter.
+ * @param <V> The class type of the parameter value.
+ * @return The parameter value. It is null only if the map holds no non-null value for the
+ * parameter and the parameter's validator accepts null.
+ * @throws IllegalArgumentException if the value is null and the parameter's validator rejects
+ * null.
+ */
+ @SuppressWarnings("unchecked")
+ default <V> V get(Param<V> param) {
+ V value = (V) getParamMap().get(param);
+ if (value != null) {
+ return value;
+ }
+
+ // A null value is handed back as-is when the validator accepts null.
+ if (param.validator.validate(null)) {
+ return null;
+ }
+ throw new IllegalArgumentException(
+ "Parameter " + param.name + "'s value should not be null");
+ }
+
+ /**
+ * Returns a map which should contain value for every parameter that meets one of the following
+ * conditions.
+ *
+ * <p>1) set(...) has been called to set value for this parameter.
+ *
+ * <p>2) The parameter is a public final field of this WithParams instance. This includes fields
+ * inherited from its interfaces and super-classes.
+ *
+ * <p>The subclass which implements this interface could meet this requirement by returning a
+ * member field of the given map type, after having initialized this member field using the
+ * {@link ParamUtils#initializeMapWithDefaultValues(Map, WithParams)} method.
+ *
+ * <p>The default {@link #set(Param, Object)}, {@link #get(Param)} and {@link #getParam(String)}
+ * methods operate directly on the returned map: set(...) stores the value with put(...), while
+ * get(...) and getParam(...) look the parameter up in it. The returned map therefore has to
+ * accept put(...), and a value is only visible to later calls if they receive the same map.
+ *
+ * @return A map which maps parameter definition to parameter value.
+ */
+ Map<Param<?>, Object> getParamMap();
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java
new file mode 100644
index 0000000..cdbe63d
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.WithParams;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Modifier;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/** Utility methods for working with the parameters of {@link WithParams} instances. */
+public class ParamUtils {
+ /**
+ * Fills the given paramMap with the default value of every public final Param-typed field of
+ * the given instance. A parameter that is already a key of the map, whatever its current
+ * value, is left untouched.
+ *
+ * <p>Note: This method should be called after all public final Param-typed fields of the given
+ * instance have been defined. A good choice is to call this method in the constructor of the
+ * given instance.
+ *
+ * @param paramMap The map to be filled with default values.
+ * @param instance The instance whose public final Param-typed fields are inspected.
+ */
+ public static void initializeMapWithDefaultValues(
+ Map<Param<?>, Object> paramMap, WithParams<?> instance) {
+ for (Param<?> candidate : getPublicFinalParamFields(instance)) {
+ if (paramMap.containsKey(candidate)) {
+ continue;
+ }
+ paramMap.put(candidate, candidate.defaultValue);
+ }
+ }
+
+ /**
+ * Returns the Param instances held by all public final Param-typed fields of the given object.
+ * The search starts at the runtime class of the object and also covers the fields inherited
+ * from its super-classes and interfaces.
+ *
+ * @param object the object whose public final Param-typed fields will be returned.
+ * @return a list of Param instances.
+ */
+ public static List<Param<?>> getPublicFinalParamFields(Object object) {
+ Class<?> runtimeClass = object.getClass();
+ return getPublicFinalParamFields(object, runtimeClass);
+ }
+
+ // A helper method that collects the values of all public final Param-typed fields declared by
+ // clazz, then recurses into the super-class and the directly implemented interfaces of clazz.
+ // The object is the instance the field values are read from. A Param that is reachable through
+ // more than one inheritance path appears once per path in the returned list.
+ private static List<Param<?>> getPublicFinalParamFields(Object object, Class<?> clazz) {
+ List<Param<?>> result = new ArrayList<>();
+ for (Field field : clazz.getDeclaredFields()) {
+ int modifiers = field.getModifiers();
+ if (!Param.class.isAssignableFrom(field.getType())
+ || !Modifier.isPublic(modifiers)
+ || !Modifier.isFinal(modifiers)) {
+ continue;
+ }
+ try {
+ // Only the fields that are actually read are made accessible. Calling
+ // setAccessible(true) on every declared field could fail for unrelated fields of
+ // inherited classes. It is still needed here because a public field of a
+ // non-public class is not accessible by default.
+ field.setAccessible(true);
+ result.add((Param<?>) field.get(object));
+ } catch (IllegalAccessException e) {
+ throw new RuntimeException(
+ "Failed to extract param from field " + field.getName(), e);
+ }
+ }
+
+ if (clazz.getSuperclass() != null) {
+ result.addAll(getPublicFinalParamFields(object, clazz.getSuperclass()));
+ }
+ for (Class<?> cls : clazz.getInterfaces()) {
+ result.addAll(getPublicFinalParamFields(object, cls));
+ }
+ return result;
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
new file mode 100644
index 0000000..283c1e5
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.ml.api.core.Stage;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.util.InstantiationUtil;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** Utility methods for reading and writing stages. */
+public class ReadWriteUtils {
+ public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+ // A helper method that encodes the given parameter value to a json string. The value is cast
+ // to the parameter's value type first, because Param::jsonEncode(...) needs the actual type of
+ // the value while the callers only hold it as an Object.
+ private static <T> String jsonEncodeHelper(Param<T> param, Object value) throws IOException {
+ @SuppressWarnings("unchecked")
+ T typedValue = (T) value;
+ return param.jsonEncode(typedValue);
+ }
+
+ // Builds a Map<String, String> from the given Map<Param<?>, Object>: every entry is keyed by
+ // the parameter name and holds the json-encoded parameter value.
+ private static Map<String, String> jsonEncode(Map<Param<?>, Object> paramMap)
+ throws IOException {
+ Map<String, String> encoded = new HashMap<>(paramMap.size());
+ for (Map.Entry<Param<?>, Object> paramAndValue : paramMap.entrySet()) {
+ Param<?> param = paramAndValue.getKey();
+ encoded.put(param.name, jsonEncodeHelper(param, paramAndValue.getValue()));
+ }
+ return encoded;
+ }
+
+ /**
+ * Writes the metadata of the given stage, merged with the extra metadata, as a json string to
+ * a file named `metadata` under the given path. The stage metadata consists of the entries
+ * `className`, `timestamp` and `paramMap`; an entry of the extra metadata with one of these
+ * names is overwritten.
+ *
+ * <p>Required: the metadata file under the given path should not exist.
+ *
+ * @param stage The stage instance.
+ * @param path The parent directory to save the stage metadata.
+ * @param extraMetadata The extra metadata to be saved.
+ * @throws IOException if the metadata file already exists or can not be written.
+ */
+ public static void saveMetadata(Stage<?> stage, String path, Map<String, ?> extraMetadata)
+ throws IOException {
+ // Creates parent directories if not already created.
+ File parentDir = new File(path);
+ parentDir.mkdirs();
+
+ Map<String, Object> allMetadata = new HashMap<>(extraMetadata);
+ allMetadata.put("className", stage.getClass().getName());
+ allMetadata.put("timestamp", System.currentTimeMillis());
+ allMetadata.put("paramMap", jsonEncode(stage.getParamMap()));
+ // TODO: add version in the metadata.
+ String serialized = OBJECT_MAPPER.writeValueAsString(allMetadata);
+
+ // The file is only created after the metadata has been serialized successfully.
+ File target = new File(path, "metadata");
+ boolean created = target.createNewFile();
+ if (!created) {
+ throw new IOException("File " + target + " already exists.");
+ }
+ try (BufferedWriter out = new BufferedWriter(new FileWriter(target))) {
+ out.write(serialized);
+ }
+ }
+
+ /**
+ * Writes the metadata of the given stage, without any extra metadata, to a file named
+ * `metadata` under the given path. The metadata of a stage includes the stage class name,
+ * parameter values etc.
+ *
+ * <p>Required: the metadata file under the given path should not exist.
+ *
+ * @param stage The stage instance.
+ * @param path The parent directory to save the stage metadata.
+ */
+ public static void saveMetadata(Stage<?> stage, String path) throws IOException {
+ Map<String, Object> noExtraMetadata = new HashMap<>();
+ saveMetadata(stage, path, noExtraMetadata);
+ }
+
+ /**
+ * Reads the file named `metadata` under the given path and parses its content as a json
+ * object. Lines that start with `#` are skipped before parsing.
+ *
+ * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not
+ * match the className of the previously saved stage.
+ *
+ * @param path The parent directory of the metadata file to read from.
+ * @param expectedClassName The expected class name of the stage.
+ * @return A map from metadata name to metadata value.
+ */
+ public static Map<String, ?> loadMetadata(String path, String expectedClassName)
+ throws IOException {
+ Path metadataFile = Paths.get(path, "metadata");
+ StringBuilder json = new StringBuilder();
+ try (BufferedReader reader = new BufferedReader(new FileReader(metadataFile.toString()))) {
+ for (String line = reader.readLine(); line != null; line = reader.readLine()) {
+ if (line.startsWith("#")) {
+ continue;
+ }
+ json.append(line);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ Map<String, ?> metadata = OBJECT_MAPPER.readValue(json.toString(), Map.class);
+
+ String actualClassName = (String) metadata.get("className");
+ // An empty expected class name disables the check.
+ boolean classNameAccepted =
+ expectedClassName.isEmpty() || expectedClassName.equals(actualClassName);
+ if (classNameAccepted) {
+ return metadata;
+ }
+ throw new RuntimeException(
+ "Class name "
+ + actualClassName
+ + " does not match the expected class name "
+ + expectedClassName
+ + ".");
+ }
+
+ // Builds the directory {parentPath}/stages/{stageIdx} used to save one stage of a Pipeline or
+ // PipelineModel. The stageIdx is left-padded with `0` so that it has as many digits as
+ // numStages, e.g. stage 3 of 120 stages is saved under {parentPath}/stages/003.
+ private static String getPathForPipelineStage(int stageIdx, int numStages, String parentPath) {
+ int width = Integer.toString(numStages).length();
+ String pattern = String.format("%%0%dd", width);
+ String paddedIdx = String.format(pattern, stageIdx);
+ return Paths.get(parentPath, "stages", paddedIdx).toString();
+ }
+
+ /**
+ * Saves a Pipeline or PipelineModel with the given list of stages to the given path. The
+ * pipeline metadata, extended with the `numStages` entry, is written first; afterwards every
+ * stage is saved to its own directory under {path}/stages.
+ *
+ * @param pipeline A Pipeline or PipelineModel instance.
+ * @param stages A list of stages of the given pipeline.
+ * @param path The parent directory to save the pipeline metadata and its stages.
+ */
+ public static void savePipeline(Stage<?> pipeline, List<Stage<?>> stages, String path)
+ throws IOException {
+ // Creates parent directories if not already created.
+ new File(path).mkdirs();
+
+ int numStages = stages.size();
+ Map<String, Object> extraMetadata = new HashMap<>();
+ extraMetadata.put("numStages", numStages);
+ saveMetadata(pipeline, path, extraMetadata);
+
+ for (int stageIdx = 0; stageIdx < numStages; stageIdx++) {
+ String stageDir = getPathForPipelineStage(stageIdx, numStages, path);
+ Stage<?> stage = stages.get(stageIdx);
+ stage.save(stageDir);
+ }
+ }
+
+ /**
+ * Loads the stages of a Pipeline or PipelineModel from the given path. The number of stages is
+ * taken from the `numStages` entry of the pipeline metadata, and every stage is loaded from
+ * its own directory under {path}/stages.
+ *
+ * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not
+ * match the className of the previously saved Pipeline or PipelineModel.
+ *
+ * @param path The parent directory to load the pipeline metadata and its stages.
+ * @param expectedClassName The expected class name of the pipeline.
+ * @return A list of stages.
+ */
+ public static List<Stage<?>> loadPipeline(String path, String expectedClassName)
+ throws IOException {
+ int numStages = (Integer) loadMetadata(path, expectedClassName).get("numStages");
+
+ List<Stage<?>> loadedStages = new ArrayList<>(numStages);
+ for (int stageIdx = 0; stageIdx < numStages; stageIdx++) {
+ String stageDir = getPathForPipelineStage(stageIdx, numStages, path);
+ Stage<?> stage = loadStage(stageDir);
+ loadedStages.add(stage);
+ }
+ return loadedStages;
+ }
+
+ // A helper method that sets the given parameter of the stage to the given value. The value is
+ // cast to the parameter's value type first, because stage::set(...) needs the actual type of
+ // the value while the callers only hold it as an Object.
+ public static <T> void setStageParam(Stage<?> stage, Param<T> param, Object value) {
+ @SuppressWarnings("unchecked")
+ T typedValue = (T) value;
+ stage.set(param, typedValue);
+ }
+
+ /**
+ * Loads the stage with the saved parameters from the given path. This method reads the metadata
+ * file under the given path, instantiates the stage using its no-argument constructor, and
+ * loads the stage with the paramMap from the metadata file.
+ *
+ * <p>A saved parameter is matched by name against the public final Param-typed fields of the
+ * instantiated stage, so every parameter in the saved paramMap must be defined as such a field.
+ *
+ * <p>Note: This method does not attempt to read model data from the given path. Caller needs to
+ * read model data from the given path if the stage has model data.
+ *
+ * <p>Required: the class with type T must have a no-argument constructor.
+ *
+ * @param path The parent directory of the stage metadata file.
+ * @param <T> The class type of the Stage subclass.
+ * @return An instance of class type T.
+ * @throws RuntimeException if the stage class can not be found, or if a saved parameter is not
+ * defined as a public final Param-typed field of the stage class.
+ */
+ @SuppressWarnings("unchecked")
+ public static <T extends Stage<T>> T loadStageParam(String path) throws IOException {
+ Map<String, ?> metadata = loadMetadata(path, "");
+ String className = (String) metadata.get("className");
+ Map<String, String> paramMap = (Map<String, String>) metadata.get("paramMap");
+
+ try {
+ Class<T> clazz = (Class<T>) Class.forName(className);
+ T instance = InstantiationUtil.instantiate(clazz);
+
+ Map<String, Param<?>> nameToParam = new HashMap<>();
+ for (Param<?> param : ParamUtils.getPublicFinalParamFields(instance)) {
+ nameToParam.put(param.name, param);
+ }
+
+ for (Map.Entry<String, String> entry : paramMap.entrySet()) {
+ Param<?> param = nameToParam.get(entry.getKey());
+ if (param == null) {
+ // Fail with the offending name instead of a bare NullPointerException when the
+ // metadata holds a parameter that the stage class does not define as a field.
+ throw new RuntimeException(
+ "Failed to load stage because the saved parameter "
+ + entry.getKey()
+ + " is not a public final Param field of class "
+ + className
+ + ".");
+ }
+ setStageParam(instance, param, param.jsonDecode(entry.getValue()));
+ }
+ return instance;
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Failed to load stage.", e);
+ }
+ }
+
+ /**
+ * Loads the stage from the given path by reflectively invoking the public static load(String)
+ * method of the stage class. The stage class name is read from the metadata file under the
+ * given path. The load() method is expected to construct the stage instance with the saved
+ * parameters, model data and other metadata if exists.
+ *
+ * <p>Required: the stage class must have a static load() method.
+ *
+ * @param path The parent directory of the stage metadata file.
+ * @return An instance of Stage.
+ */
+ public static Stage<?> loadStage(String path) throws IOException {
+ String className = (String) loadMetadata(path, "").get("className");
+
+ try {
+ Method loader = Class.forName(className).getMethod("load", String.class);
+ loader.setAccessible(true);
+ // The receiver is null because load(String) is expected to be static.
+ Object loaded = loader.invoke(null, path);
+ return (Stage<?>) loaded;
+ } catch (NoSuchMethodException e) {
+ throw new RuntimeException(
+ "Failed to load stage because the static method "
+ + String.format("%s::load(String)", className)
+ + " is not implemented.",
+ e);
+ } catch (ClassNotFoundException | IllegalAccessException | InvocationTargetException e) {
+ throw new RuntimeException("Failed to load stage.", e);
+ }
+ }
+}
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
new file mode 100644
index 0000000..2e4b4c2
--- /dev/null
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.api.common.state.BroadcastState;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.util.HashMap;
+import java.util.Map;
+
+/** Defines a few Stage subclasses to be used in unit tests. */
+public class ExampleStages {
+ /**
+ * A Model subclass that increments every value in the input stream by `delta` and outputs the
+ * resulting values.
+ *
+ * <p>The `delta` is carried by the model data stream that is provided via
+ * {@link #setModelData(Table...)}. The stream is expected to hold exactly one Integer value.
+ */
+ public static class SumModel implements Model<SumModel> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private DataStream<Integer> modelData;
+
+ // This empty constructor is necessary in order for SumModel to be loaded by
+ // ReadWriteUtils.loadStageParam
+ public SumModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ // Adds delta to every value of the single input table and returns the result as one table.
+ @Override
+ public Table[] transform(Table... inputs) {
+ Assert.assertEquals(1, inputs.length);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Integer> input = tEnv.toDataStream(inputs[0], Integer.class);
+ // The input values are the 1st input and the broadcast model data is the 2nd input of
+ // ApplyDeltaOperator.
+ DataStream<Integer> output =
+ input.connect(modelData.broadcast())
+ .transform(
+ "ApplyDeltaOperator",
+ BasicTypeInfo.INT_TYPE_INFO,
+ new ApplyDeltaOperator());
+
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ // Uses the first input table as the model data; any further table is ignored.
+ @Override
+ public void setModelData(Table... inputs) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ modelData = tEnv.toDataStream(inputs[0], Integer.class);
+ }
+
+ // Saves the stage metadata under the given path and the delta as an int in the file
+ // {path}/data/delta.
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+
+ File dataDir = new File(path, "data");
+ if (!dataDir.mkdir()) {
+ throw new IOException("Directory " + dataDir.toString() + " already exists.");
+ }
+
+ File dataFile = new File(dataDir, "delta");
+ if (!dataFile.createNewFile()) {
+ throw new IOException("File " + dataFile.toString() + " already exists.");
+ }
+
+ // Executes the job that produces the model data and takes the first collected value as
+ // delta.
+ int delta;
+ try {
+ delta = (Integer) IteratorUtils.toList(modelData.executeAndCollect()).get(0);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+
+ try (DataOutputStream outputStream =
+ new DataOutputStream(new FileOutputStream(dataFile))) {
+ outputStream.writeInt(delta);
+ }
+ }
+
+ // Restores the parameters from the metadata file first, then rebuilds the model data from
+ // the int stored in {path}/data/delta. This method is looked up by name by
+ // ReadWriteUtils.loadStage.
+ public static SumModel load(String path) throws IOException {
+ SumModel model = ReadWriteUtils.loadStageParam(path);
+ File dataFile = Paths.get(path, "data", "delta").toFile();
+
+ try (DataInputStream inputStream = new DataInputStream(new FileInputStream(dataFile))) {
+ StreamExecutionEnvironment env =
+ StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ Table modelData = tEnv.fromDataStream(env.fromElements(inputStream.readInt()));
+ model.setModelData(modelData);
+ return model;
+ }
+ }
+ }
+
+ // Adds delta from the 2nd input to every element in the 1st input and returns the added values.
+ // Elements of the 1st input that arrive before delta are kept in unProcessedValues and are
+ // emitted as soon as delta arrives. The delta is kept in broadcastState under the key "delta".
+ private static class ApplyDeltaOperator extends AbstractStreamOperator<Integer>
+ implements TwoInputStreamOperator<Integer, Integer, Integer> {
+ private ListState<Integer> unProcessedValues;
+ private BroadcastState<String, Integer> broadcastState = null;
+
+ // Obtains both states from the operator state store.
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ unProcessedValues =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<Integer>(
+ "unProcessedValues", Integer.class));
+ broadcastState =
+ context.getOperatorStateStore()
+ .getBroadcastState(
+ new MapStateDescriptor<String, Integer>(
+ "broadcastState", String.class, Integer.class));
+ }
+
+ // Handles a value of the 1st input: buffers it while delta is unknown, otherwise emits
+ // value + delta right away.
+ @Override
+ public void processElement1(StreamRecord<Integer> record) throws Exception {
+ if (broadcastState.get("delta") == null) {
+ unProcessedValues.add(record.getValue());
+ } else {
+ output.collect(new StreamRecord<>(record.getValue() + broadcastState.get("delta")));
+ }
+ }
+
+ // Handles the delta from the 2nd input: rejects a second delta, stores the first one and
+ // emits all values that were buffered so far.
+ @Override
+ public void processElement2(StreamRecord<Integer> record) throws Exception {
+ if (broadcastState.get("delta") != null) {
+ throw new IllegalStateException("Model data should have exactly one value");
+ }
+ broadcastState.put("delta", record.getValue());
+
+ for (Integer value : unProcessedValues.get()) {
+ output.collect(new StreamRecord<>(value + record.getValue()));
+ }
+ unProcessedValues.clear();
+ }
+ }
+
+ /**
+ * An Estimator subclass which calculates the sum of input values and instantiates a SumModel
+ * instance with delta=sum(inputs).
+ */
+ public static class SumEstimator implements Estimator<SumEstimator, SumModel> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public SumEstimator() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ // Sums up the values of the single input table with SumOperator and uses the resulting
+ // stream as the model data of a new SumModel.
+ @Override
+ public SumModel fit(Table... inputs) {
+ Assert.assertEquals(1, inputs.length);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ DataStream<Integer> input = tEnv.toDataStream(inputs[0], Integer.class);
+ // SumOperator runs with parallelism 1.
+ DataStream<Integer> modelData =
+ input.transform("SumOperator", BasicTypeInfo.INT_TYPE_INFO, new SumOperator())
+ .setParallelism(1);
+ try {
+ SumModel model = new SumModel();
+ model.setModelData(tEnv.fromDataStream(modelData));
+
+ return model;
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ // The estimator has no model data, so only the stage metadata is saved.
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ // Restores the estimator with its saved parameters. This method is looked up by name by
+ // ReadWriteUtils.loadStage.
+ public static SumEstimator load(String path) throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+ }
+
+ // Accumulates the sum of all input values and emits it as a single record from endInput().
+ private static class SumOperator extends AbstractStreamOperator<Integer>
+ implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {
+ // Running sum of the values seen so far; it is a plain field and not operator state.
+ int sum = 0;
+
+ @Override
+ public void endInput() throws Exception {
+ output.collect(new StreamRecord<>(sum));
+ }
+
+ @Override
+ public void processElement(StreamRecord<Integer> input) throws Exception {
+ sum += input.getValue();
+ }
+ }
+}
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
index 6d46430..74f9c65 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
@@ -18,141 +18,103 @@
package org.apache.flink.ml.api.core;
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.ml.api.core.ExampleStages.SumEstimator;
+import org.apache.flink.ml.api.core.ExampleStages.SumModel;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
-import org.junit.Rule;
+import org.apache.commons.collections.IteratorUtils;
import org.junit.Test;
-import org.junit.rules.ExpectedException;
-import java.io.IOException;
-import java.util.ArrayList;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Comparator;
import java.util.List;
-/** Tests the behavior of {@link Pipeline}. */
-public class PipelineTest {
- @Rule public ExpectedException thrown = ExpectedException.none();
+/** Tests the behavior of Pipeline and PipelineModel. */
+public class PipelineTest extends AbstractTestBase {
- @Test
- public void testPipelineBehavior() {
- List<Stage<?>> stages = new ArrayList<>();
- stages.add(new MockTransformer("a"));
- stages.add(new MockEstimator("b"));
- stages.add(new MockEstimator("c"));
- stages.add(new MockTransformer("d"));
-
- Pipeline pipeline = new Pipeline(stages);
- assert describePipeline(pipeline.getStages()).equals("a_b_c_d");
-
- PipelineModel pipelineModel = pipeline.fit(null, null);
- assert describePipeline(pipelineModel.getStages()).equals("a_mb_mc_d");
- }
-
- private static String describePipeline(List<Stage<?>> stages) {
- StringBuilder res = new StringBuilder();
- for (Stage<?> s : stages) {
- if (res.length() != 0) {
- res.append("_");
- }
- res.append(((SelfDescribe) s).describe());
- }
- return res.toString();
- }
-
- /** Interface to describe a class with a string, only for pipeline test. */
- private interface SelfDescribe {
- ParamInfo<String> DESCRIPTION =
- ParamInfoFactory.createParamInfo("description", String.class).build();
-
- String describe();
- }
-
- /** Mock estimator for pipeline test. */
- public static class MockEstimator implements Estimator<MockEstimator, MockModel>, SelfDescribe {
- private final Params params = new Params();
+ // Executes the given stage and verifies that it produces the expected output.
+ private static void executeAndCheckOutput(
+ Stage<?> stage, List<Integer> input, List<Integer> expectedOutput) throws Exception {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ env.setParallelism(4);
- public MockEstimator() {}
+ Table inputTable = tEnv.fromDataStream(env.fromCollection(input));
- MockEstimator(String description) {
- set(DESCRIPTION, description);
- }
-
- @Override
- public MockModel fit(Table... inputs) {
- return new MockModel("m" + describe());
- }
-
- @Override
- public Params getParams() {
- return params;
- }
+ Table outputTable;
- @Override
- public String describe() {
- return get(DESCRIPTION);
+ if (stage instanceof AlgoOperator) {
+ outputTable = ((AlgoOperator<?>) stage).transform(inputTable)[0];
+ } else {
+ Estimator<?, ?> estimator = (Estimator<?, ?>) stage;
+ Model<?> model = estimator.fit(inputTable);
+ outputTable = model.transform(inputTable)[0];
}
- @Override
- public void save(String path) throws IOException {}
+ List<Integer> output =
+ IteratorUtils.toList(
+ tEnv.toDataStream(outputTable, Integer.class).executeAndCollect());
+ compareResultCollections(expectedOutput, output, Comparator.naturalOrder());
}
- /** Mock transformer for pipeline test. */
- public static class MockTransformer implements Transformer<MockTransformer>, SelfDescribe {
- private final Params params = new Params();
-
- public MockTransformer() {}
-
- MockTransformer(String description) {
- set(DESCRIPTION, description);
- }
-
- @Override
- public Table[] transform(Table... inputs) {
- return inputs;
- }
-
- @Override
- public Params getParams() {
- return params;
- }
-
- @Override
- public String describe() {
- return get(DESCRIPTION);
- }
-
- @Override
- public void save(String path) throws IOException {}
+ // Checks that a PipelineModel made of three SumModel stages produces the expected output, both
+ // as constructed and after a save/load round trip.
+ // NOTE(review): the temporary directory created below is not deleted by this test.
+ @Test
+ public void testPipelineModel() throws Exception {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ // Builds a PipelineModel that increments each input value by 60. This PipelineModel consists
+ // of three stages that increment the input value by 10, 20, and 30 respectively.
+ SumModel modelA = new SumModel();
+ modelA.setModelData(tEnv.fromValues(10));
+ SumModel modelB = new SumModel();
+ modelB.setModelData(tEnv.fromValues(20));
+ SumModel modelC = new SumModel();
+ modelC.setModelData(tEnv.fromValues(30));
+
+ List<Stage<?>> stages = Arrays.asList(modelA, modelB, modelC);
+ Model<?> model = new PipelineModel(stages);
+
+ // Executes the original PipelineModel and verifies that it produces the expected output.
+ executeAndCheckOutput(model, Arrays.asList(1, 2, 3), Arrays.asList(61, 62, 63));
+
+ // Saves and loads the PipelineModel.
+ Path tempDir = Files.createTempDirectory("PipelineTest");
+ String path = Paths.get(tempDir.toString(), "testPipelineModelSaveLoad").toString();
+ model.save(path);
+ Model<?> loadedModel = PipelineModel.load(path);
+
+ // Executes the loaded PipelineModel and verifies that it produces the expected output.
+ executeAndCheckOutput(loadedModel, Arrays.asList(1, 2, 3), Arrays.asList(61, 62, 63));
}
- /** Mock model for pipeline test. */
- public static class MockModel implements Model<MockModel>, SelfDescribe {
- private final Params params = new Params();
-
- public MockModel() {}
-
- MockModel(String description) {
- set(DESCRIPTION, description);
- }
-
- @Override
- public Table[] transform(Table... inputs) {
- return inputs;
- }
-
- @Override
- public Params getParams() {
- return params;
- }
-
- @Override
- public String describe() {
- return get(DESCRIPTION);
- }
-
- @Override
- public void save(String path) throws IOException {}
+ // Checks that a Pipeline containing two SumModel stages and one SumEstimator produces the
+ // expected output, both as constructed and after a save/load round trip.
+ // NOTE(review): the temporary directory created below is not deleted by this test.
+ @Test
+ public void testPipeline() throws Exception {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ // Builds a Pipeline that consists of a Model, an Estimator, and another Model.
+ SumModel modelA = new SumModel();
+ modelA.setModelData(tEnv.fromValues(10));
+ SumModel modelB = new SumModel();
+ modelB.setModelData(tEnv.fromValues(30));
+
+ List<Stage<?>> stages = Arrays.asList(modelA, new SumEstimator(), modelB);
+ Estimator<?, ?> estimator = new Pipeline(stages);
+
+ // Executes the original Pipeline and verifies that it produces the expected output.
+ executeAndCheckOutput(estimator, Arrays.asList(1, 2, 3), Arrays.asList(77, 78, 79));
+
+ // Saves and loads the Pipeline.
+ Path tempDir = Files.createTempDirectory("PipelineTest");
+ String path = Paths.get(tempDir.toString(), "testPipeline").toString();
+ estimator.save(path);
+ Estimator<?, ?> loadedEstimator = Pipeline.load(path);
+
+ // Executes the loaded Pipeline and verifies that it produces the expected output.
+ executeAndCheckOutput(loadedEstimator, Arrays.asList(1, 2, 3), Arrays.asList(77, 78, 79));
}
}
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
new file mode 100644
index 0000000..9e03ddb
--- /dev/null
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java
@@ -0,0 +1,375 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.FloatArrayParam;
+import org.apache.flink.ml.param.FloatParam;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.LongArrayParam;
+import org.apache.flink.ml.param.LongParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** Tests the behavior of Stage and WithParams. */
+public class StageTest {
+
+ // A WithParams sub-interface that declares one parameter for each pre-defined parameter type
+ // (boolean, int, long, float, double, string, and the array variant of each non-boolean type).
+ // The four scalar numeric parameters use ParamValidators.lt(100); the tests in this class rely
+ // on the value 100 being rejected for them.
+ private interface MyParams<T> extends WithParams<T> {
+ // No validator is passed for this parameter.
+ Param<Boolean> BOOLEAN_PARAM = new BooleanParam("booleanParam", "Description", false);
+
+ Param<Integer> INT_PARAM =
+ new IntParam("intParam", "Description", 1, ParamValidators.lt(100));
+
+ Param<Long> LONG_PARAM =
+ new LongParam("longParam", "Description", 2L, ParamValidators.lt(100));
+
+ Param<Float> FLOAT_PARAM =
+ new FloatParam("floatParam", "Description", 3.0f, ParamValidators.lt(100));
+
+ Param<Double> DOUBLE_PARAM =
+ new DoubleParam("doubleParam", "Description", 4.0, ParamValidators.lt(100));
+
+ Param<String> STRING_PARAM = new StringParam("stringParam", "Description", "5");
+
+ // The array parameters alternate between the constructor without a validator and the one
+ // that takes an explicit ParamValidators.alwaysTrue().
+ Param<Integer[]> INT_ARRAY_PARAM =
+ new IntArrayParam("intArrayParam", "Description", new Integer[] {6, 7});
+
+ Param<Long[]> LONG_ARRAY_PARAM =
+ new LongArrayParam(
+ "longArrayParam",
+ "Description",
+ new Long[] {8L, 9L},
+ ParamValidators.alwaysTrue());
+
+ Param<Float[]> FLOAT_ARRAY_PARAM =
+ new FloatArrayParam("floatArrayParam", "Description", new Float[] {10.0f, 11.0f});
+
+ Param<Double[]> DOUBLE_ARRAY_PARAM =
+ new DoubleArrayParam(
+ "doubleArrayParam",
+ "Description",
+ new Double[] {12.0, 13.0},
+ ParamValidators.alwaysTrue());
+
+ Param<String[]> STRING_ARRAY_PARAM =
+ new StringArrayParam("stringArrayParam", "Description", new String[] {"14", "15"});
+ }
+
+ /**
+ * A Stage subclass which inherits all parameters from MyParams and defines two extra
+ * parameters as public instance fields: {@code extraIntParam} and {@code paramWithNullDefault}.
+ */
+ public static class MyStage implements Stage<MyStage>, MyParams<MyStage> {
+ // Backing store returned by getParamMap().
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ // Declared on the instance rather than as an interface constant like the MyParams parameters.
+ public final Param<Integer> extraIntParam =
+ new IntParam("extraIntParam", "Description", 20, ParamValidators.alwaysTrue());
+
+ // Default is null while the validator is notNull(); the tests in this class expect get() to
+ // throw until a value has been set explicitly.
+ public final Param<Integer> paramWithNullDefault =
+ new IntParam(
+ "paramWithNullDefault",
+ "Must be explicitly set with a non-null value",
+ null,
+ ParamValidators.notNull());
+
+ public MyStage() {
+ // Presumably fills paramMap with the default value of every parameter of this stage;
+ // confirm against ParamUtils.
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ // Static load() counterpart of save(); MyStageWithoutLoad deliberately lacks this method.
+ public static MyStage load(String path) throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+ }
+
+    /**
+     * A Stage subclass that deliberately omits the static load() method, used to exercise the
+     * failure path when such a stage is loaded after being saved.
+     *
+     * <p>The self type argument is this class itself. It previously named MyStage, a copy-paste
+     * slip that made methods returning the self type advertise the wrong class.
+     */
+    public static class MyStageWithoutLoad
+            implements Stage<MyStageWithoutLoad>, MyParams<MyStageWithoutLoad> {
+        // Backing store returned by getParamMap().
+        private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+        public MyStageWithoutLoad() {
+            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        }
+
+        @Override
+        public void save(String path) throws IOException {
+            ReadWriteUtils.saveMetadata(this, path);
+        }
+
+        @Override
+        public Map<Param<?>, Object> getParamMap() {
+            return paramMap;
+        }
+    }
+
+ // Asserts that m1 and m2 are equivalent.
+    // Fails unless both maps are non-null, have the same size, and map every key of m1 to an
+    // equivalent value in m2. Array values are compared element by element.
+    private static void assertParamMapEquals(Map<Param<?>, Object> m1, Map<Param<?>, Object> m2) {
+        Assert.assertTrue(m1 != null && m2 != null);
+        Assert.assertEquals(m1.size(), m2.size());
+
+        for (Map.Entry<Param<?>, Object> expectedEntry : m1.entrySet()) {
+            Param<?> key = expectedEntry.getKey();
+            Assert.assertTrue(m2.containsKey(key));
+
+            Object expected = expectedEntry.getValue();
+            Object actual = m2.get(key);
+            if (expected != null && actual != null) {
+                boolean bothArrays = expected.getClass().isArray() && actual.getClass().isArray();
+                if (bothArrays) {
+                    Assert.assertArrayEquals((Object[]) expected, (Object[]) actual);
+                } else {
+                    Assert.assertEquals(expected, actual);
+                }
+            } else {
+                // At least one side is null, so the other must be null as well.
+                Assert.assertTrue(expected == null && actual == null);
+            }
+        }
+    }
+
+ // Saves and loads the given stage. And verifies that the loaded stage has same parameter values
+ // as the original stage.
+    // Applies the given overrides (keyed by parameter name) to the stage, saves it, checks that a
+    // second save to the same path raises an IOException, reloads it, and checks that the reloaded
+    // stage carries the same parameter values. Returns the reloaded stage.
+    private static Stage<?> validateStageSaveLoad(
+            Stage<?> stage, Map<String, Object> paramOverrides) throws IOException {
+        for (Map.Entry<String, Object> override : paramOverrides.entrySet()) {
+            Param<?> target = stage.getParam(override.getKey());
+            ReadWriteUtils.setStageParam(stage, target, override.getValue());
+        }
+
+        String path = Paths.get(Files.createTempDirectory("").toString(), "test").toString();
+        stage.save(path);
+
+        // Saving again to the same path must be rejected.
+        boolean secondSaveRejected = false;
+        try {
+            stage.save(path);
+        } catch (IOException e) {
+            // This is expected.
+            secondSaveRejected = true;
+        }
+        if (!secondSaveRejected) {
+            Assert.fail("Expected IOException");
+        }
+
+        Stage<?> reloaded = ReadWriteUtils.loadStage(path);
+        for (Map.Entry<String, Object> override : paramOverrides.entrySet()) {
+            Param<?> target = reloaded.getParam(override.getKey());
+            Assert.assertEquals(override.getValue(), reloaded.get(target));
+        }
+        assertParamMapEquals(stage.getParamMap(), reloaded.getParamMap());
+        return reloaded;
+    }
+
+    // A parameter can be addressed through its constant or through a by-name lookup, including
+    // the parameter declared on the stage instance itself.
+    @Test
+    public void testParamSetValueWithName() {
+        MyStage stage = new MyStage();
+
+        // Addressed through the interface constant.
+        Param<Integer> byConstant = MyParams.INT_PARAM;
+        stage.set(byConstant, 2);
+        Assert.assertEquals(2, stage.get(byConstant).intValue());
+
+        // Addressed by name.
+        Param<Integer> byName = stage.getParam("intParam");
+        stage.set(byName, 3);
+        Assert.assertEquals(3, stage.get(byName).intValue());
+
+        // The instance-level parameter, addressed by name.
+        Param<Integer> extraByName = stage.getParam("extraIntParam");
+        stage.set(extraByName, 50);
+        Assert.assertEquals(50, stage.get(extraByName).intValue());
+    }
+
+    // Reading a parameter whose default is null must fail until a value has been set.
+    @Test
+    public void testParamWithNullDefault() {
+        MyStage stage = new MyStage();
+
+        IllegalArgumentException failure = null;
+        try {
+            stage.get(stage.paramWithNullDefault);
+        } catch (IllegalArgumentException e) {
+            failure = e;
+        }
+        if (failure == null) {
+            Assert.fail("Expected IllegalArgumentException");
+        }
+        Assert.assertTrue(failure.getMessage().contains("should not be null"));
+
+        // Once a value is set explicitly, it can be read back.
+        stage.set(stage.paramWithNullDefault, 3);
+        Assert.assertEquals(3, stage.get(stage.paramWithNullDefault).intValue());
+    }
+
+    // Fails unless setting the value raises an IllegalArgumentException whose message mentions
+    // "invalid value".
+    private static <T> void assertInvalidValue(Stage<?> stage, Param<T> param, T value) {
+        IllegalArgumentException failure = null;
+        try {
+            stage.set(param, value);
+        } catch (IllegalArgumentException e) {
+            failure = e;
+        }
+        if (failure == null) {
+            Assert.fail("Expected IllegalArgumentException");
+        }
+        Assert.assertTrue(failure.getMessage().contains("invalid value"));
+    }
+
+    // Fails unless setting a value of the wrong class raises a ClassCastException whose message
+    // mentions "incompatible class".
+    private static <T> void assertInvalidClass(Stage<?> stage, Param<T> param, Object value) {
+        // The cast is intentionally unchecked: T is erased, so no check happens here and the
+        // wrongly typed value reaches set(), which is the call under test. The suppression is
+        // confined to this local so no other warning in the method is hidden.
+        @SuppressWarnings("unchecked")
+        T wronglyTyped = (T) value;
+        try {
+            stage.set(param, wronglyTyped);
+            Assert.fail("Expected ClassCastException");
+        } catch (ClassCastException e) {
+            Assert.assertTrue(e.getMessage().contains("incompatible class"));
+        }
+    }
+
+    // Out-of-range values and values of the wrong class must both be rejected.
+    @Test
+    public void testParamSetInvalidValue() {
+        MyStage stage = new MyStage();
+
+        // 100 is rejected by each of the scalar numeric parameters.
+        assertInvalidValue(stage, MyParams.INT_PARAM, 100);
+        assertInvalidValue(stage, MyParams.LONG_PARAM, 100L);
+        assertInvalidValue(stage, MyParams.FLOAT_PARAM, 100.0f);
+        assertInvalidValue(stage, MyParams.DOUBLE_PARAM, 100.0);
+
+        // Values whose class does not match the parameter.
+        assertInvalidClass(stage, MyParams.INT_PARAM, "100");
+        assertInvalidClass(stage, MyParams.STRING_PARAM, 100);
+
+        // A string parameter looked up by name under the wrong static type.
+        Param<Integer> mistypedStringParam = stage.getParam("stringParam");
+        assertInvalidClass(stage, mistypedStringParam, 50);
+    }
+
+    // Every pre-defined parameter type accepts a valid value and returns it on read.
+    @Test
+    public void testParamSetValidValue() {
+        MyStage stage = new MyStage();
+
+        // Scalar parameters.
+        stage.set(MyParams.BOOLEAN_PARAM, true);
+        Assert.assertEquals(true, stage.get(MyParams.BOOLEAN_PARAM));
+
+        stage.set(MyParams.INT_PARAM, 50);
+        Assert.assertEquals(50, stage.get(MyParams.INT_PARAM).intValue());
+
+        stage.set(MyParams.LONG_PARAM, 50L);
+        Assert.assertEquals(50L, stage.get(MyParams.LONG_PARAM).longValue());
+
+        stage.set(MyParams.FLOAT_PARAM, 50f);
+        Assert.assertEquals(50f, stage.get(MyParams.FLOAT_PARAM).floatValue(), 0.0001);
+
+        stage.set(MyParams.DOUBLE_PARAM, 50.0);
+        Assert.assertEquals(50, stage.get(MyParams.DOUBLE_PARAM).doubleValue(), 0.0001);
+
+        stage.set(MyParams.STRING_PARAM, "50");
+        Assert.assertEquals("50", stage.get(MyParams.STRING_PARAM));
+
+        // Array parameters. Each expected array is a separate instance from the one passed to
+        // set().
+        Integer[] expectedInts = {50, 51};
+        stage.set(MyParams.INT_ARRAY_PARAM, new Integer[] {50, 51});
+        Assert.assertArrayEquals(expectedInts, stage.get(MyParams.INT_ARRAY_PARAM));
+
+        Long[] expectedLongs = {50L, 51L};
+        stage.set(MyParams.LONG_ARRAY_PARAM, new Long[] {50L, 51L});
+        Assert.assertArrayEquals(expectedLongs, stage.get(MyParams.LONG_ARRAY_PARAM));
+
+        Float[] expectedFloats = {50.0f, 51.0f};
+        stage.set(MyParams.FLOAT_ARRAY_PARAM, new Float[] {50.0f, 51.0f});
+        Assert.assertArrayEquals(expectedFloats, stage.get(MyParams.FLOAT_ARRAY_PARAM));
+
+        Double[] expectedDoubles = {50.0, 51.0};
+        stage.set(MyParams.DOUBLE_ARRAY_PARAM, new Double[] {50.0, 51.0});
+        Assert.assertArrayEquals(expectedDoubles, stage.get(MyParams.DOUBLE_ARRAY_PARAM));
+
+        String[] expectedStrings = {"50", "51"};
+        stage.set(MyParams.STRING_ARRAY_PARAM, new String[] {"50", "51"});
+        Assert.assertArrayEquals(expectedStrings, stage.get(MyParams.STRING_ARRAY_PARAM));
+    }
+
+    // A stage saved without overrides comes back with the default value of intParam.
+    @Test
+    public void testStageSaveLoad() throws IOException {
+        MyStage original = new MyStage();
+        original.set(original.paramWithNullDefault, 1);
+
+        Stage<?> reloaded = validateStageSaveLoad(original, Collections.emptyMap());
+
+        int reloadedIntParam = reloaded.get(MyParams.INT_PARAM);
+        Assert.assertEquals(1, reloadedIntParam);
+    }
+
+    // A by-name override applied before saving is visible on the reloaded stage.
+    @Test
+    public void testStageSaveLoadWithParamOverrides() throws IOException {
+        MyStage original = new MyStage();
+        original.set(original.paramWithNullDefault, 1);
+
+        Map<String, Object> overrides = Collections.singletonMap("intParam", 10);
+        Stage<?> reloaded = validateStageSaveLoad(original, overrides);
+
+        int reloadedIntParam = reloaded.get(MyParams.INT_PARAM);
+        Assert.assertEquals(10, reloadedIntParam);
+    }
+
+    // The save/load round trip of a stage class that has no static load() method must fail with a
+    // RuntimeException whose message mentions "not implemented".
+    @Test
+    public void testStageLoadWithoutLoadMethod() throws IOException {
+        MyStageWithoutLoad stage = new MyStageWithoutLoad();
+
+        RuntimeException failure = null;
+        try {
+            validateStageSaveLoad(stage, Collections.emptyMap());
+        } catch (RuntimeException e) {
+            failure = e;
+        }
+        if (failure == null) {
+            Assert.fail("Expected RuntimeException");
+        }
+        Assert.assertTrue(failure.getMessage().contains("not implemented"));
+    }
+
+    // Exercises each factory in ParamValidators with null, boundary and interior values.
+    @Test
+    public void testValidators() {
+        // Strictly greater than 10.
+        ParamValidator<Integer> greaterThan = ParamValidators.gt(10);
+        Assert.assertFalse(greaterThan.validate(null));
+        Assert.assertFalse(greaterThan.validate(5));
+        Assert.assertFalse(greaterThan.validate(10));
+        Assert.assertTrue(greaterThan.validate(15));
+
+        // Greater than or equal to 10.
+        ParamValidator<Integer> greaterOrEqual = ParamValidators.gtEq(10);
+        Assert.assertFalse(greaterOrEqual.validate(null));
+        Assert.assertFalse(greaterOrEqual.validate(5));
+        Assert.assertTrue(greaterOrEqual.validate(10));
+        Assert.assertTrue(greaterOrEqual.validate(15));
+
+        // Strictly less than 10.
+        ParamValidator<Integer> lessThan = ParamValidators.lt(10);
+        Assert.assertFalse(lessThan.validate(null));
+        Assert.assertTrue(lessThan.validate(5));
+        Assert.assertFalse(lessThan.validate(10));
+        Assert.assertFalse(lessThan.validate(15));
+
+        // Less than or equal to 10.
+        ParamValidator<Integer> lessOrEqual = ParamValidators.ltEq(10);
+        Assert.assertFalse(lessOrEqual.validate(null));
+        Assert.assertTrue(lessOrEqual.validate(5));
+        Assert.assertTrue(lessOrEqual.validate(10));
+        Assert.assertFalse(lessOrEqual.validate(15));
+
+        // Two-argument inRange: both bounds are accepted.
+        ParamValidator<Integer> closedRange = ParamValidators.inRange(5, 15);
+        Assert.assertFalse(closedRange.validate(null));
+        Assert.assertFalse(closedRange.validate(0));
+        Assert.assertTrue(closedRange.validate(5));
+        Assert.assertTrue(closedRange.validate(10));
+        Assert.assertTrue(closedRange.validate(15));
+        Assert.assertFalse(closedRange.validate(20));
+
+        // Four-argument inRange with both flags false: both bounds are rejected.
+        ParamValidator<Integer> openRange = ParamValidators.inRange(5, 15, false, false);
+        Assert.assertFalse(openRange.validate(null));
+        Assert.assertFalse(openRange.validate(0));
+        Assert.assertFalse(openRange.validate(5));
+        Assert.assertTrue(openRange.validate(10));
+        Assert.assertFalse(openRange.validate(15));
+        Assert.assertFalse(openRange.validate(20));
+
+        // Membership in a fixed set of values.
+        ParamValidator<Integer> oneOf = ParamValidators.inArray(1, 2, 3);
+        Assert.assertFalse(oneOf.validate(null));
+        Assert.assertTrue(oneOf.validate(1));
+        Assert.assertFalse(oneOf.validate(0));
+
+        // Anything but null.
+        ParamValidator<Integer> nonNull = ParamValidators.notNull();
+        Assert.assertTrue(nonNull.validate(5));
+        Assert.assertFalse(nonNull.validate(null));
+    }
+}
diff --git a/pom.xml b/pom.xml
index 5eb1805..66530a0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -53,8 +53,6 @@ under the License.
<modules>
<module>flink-ml-api</module>
- <module>flink-ml-lib</module>
- <module>flink-ml-uber</module>
<module>flink-ml-iteration</module>
<module>flink-ml-tests</module>
</modules>