You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/11/07 08:03:36 UTC
[flink-ml] branch master updated: [FLINK-29434] Add AlgoOperator for RandomSplitter
This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 0aad2f9 [FLINK-29434] Add AlgoOperator for RandomSplitter
0aad2f9 is described below
commit 0aad2f942887e0513093106faf9680bb9de8913e
Author: weibo <wb...@pku.edu.cn>
AuthorDate: Mon Nov 7 16:03:30 2022 +0800
[FLINK-29434] Add AlgoOperator for RandomSplitter
This closes #160.
---
.../docs/operators/feature/randomsplitter.md | 147 +++++++++++++++++++++
.../ml/examples/feature/RandomSplitterExample.java | 65 +++++++++
.../ml/feature/randomsplitter/RandomSplitter.java | 129 ++++++++++++++++++
.../randomsplitter/RandomSplitterParams.java | 65 +++++++++
.../flink/ml/feature/RandomSplitterTest.java | 137 +++++++++++++++++++
.../examples/ml/feature/randomsplitter_example.py | 63 +++++++++
.../pyflink/ml/lib/feature/normalizer.py | 2 +-
.../pyflink/ml/lib/feature/randomsplitter.py | 80 +++++++++++
.../ml/lib/feature/tests/test_normalizer.py | 1 +
.../ml/lib/feature/tests/test_randomsplitter.py | 77 +++++++++++
10 files changed, 765 insertions(+), 1 deletion(-)
diff --git a/docs/content/docs/operators/feature/randomsplitter.md b/docs/content/docs/operators/feature/randomsplitter.md
new file mode 100644
index 0000000..3fb2c22
--- /dev/null
+++ b/docs/content/docs/operators/feature/randomsplitter.md
@@ -0,0 +1,147 @@
+---
+title: "RandomSplitter"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/randomsplitter.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## RandomSplitter
+
+An AlgoOperator which splits a table into N tables according to the given weights.
+
+### Parameters
+
+| Key | Default | Type | Required | Description |
+|:--------|:-------------|:---------|:---------|:-------------------------------|
+| weights | `[1.0, 1.0]` | Double[] | no | The weights of data splitting. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates a RandomSplitter instance and uses it for data splitting. */
+public class RandomSplitterExample {
+ // Entry point: builds a nine-row input table, splits it in two with a RandomSplitter
+ // and prints the rows of each resulting table to standard output.
+ public static void main(String[] args) {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+ // Generates input data: nine rows, each holding three integer fields.
+ DataStream<Row> inputStream =
+ env.fromElements(
+ Row.of(1, 10, 0),
+ Row.of(1, 10, 0),
+ Row.of(1, 10, 0),
+ Row.of(4, 10, 0),
+ Row.of(5, 10, 0),
+ Row.of(6, 10, 0),
+ Row.of(7, 10, 0),
+ Row.of(10, 10, 0),
+ Row.of(13, 10, 3));
+ Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+ // Creates a RandomSplitter object and initializes its parameters.
+ // The relative weights 4.0 : 6.0 correspond to the 40% / 60% headings printed below.
+ RandomSplitter splitter = new RandomSplitter().setWeights(4.0, 6.0);
+
+ // Uses the RandomSplitter to split inputTable; one output table is returned per weight.
+ Table[] outputTables = splitter.transform(inputTable);
+
+ // Extracts and displays the results.
+ System.out.println("Split Result 1 (40%)");
+ // Iterates over the rows collected from the first output table.
+ for (CloseableIterator<Row> it = outputTables[0].execute().collect(); it.hasNext(); ) {
+ System.out.printf("%s\n", it.next());
+ }
+ System.out.println("Split Result 2 (60%)");
+ // Iterates over the rows collected from the second output table.
+ for (CloseableIterator<Row> it = outputTables[1].execute().collect(); it.hasNext(); ) {
+ System.out.printf("%s\n", it.next());
+ }
+ }
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+# Simple program that creates a RandomSplitter instance and uses it for data splitting.
+# It builds a small in-memory table, splits it in two and prints the rows of each part.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.randomsplitter import RandomSplitter
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input table: nine rows with three INT columns named f0, f1 and f2.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ (1, 10, 0),
+ (1, 10, 0),
+ (1, 10, 0),
+ (4, 10, 0),
+ (5, 10, 0),
+ (6, 10, 0),
+ (7, 10, 0),
+ (10, 10, 0),
+ (13, 10, 0)
+ ],
+ type_info=Types.ROW_NAMED(
+ ['f0', 'f1', "f2"],
+ [Types.INT(), Types.INT(), Types.INT()])))
+
+# Creates a RandomSplitter object and initializes its parameters.
+# The relative weights 4.0 : 6.0 correspond to the 40% / 60% headings printed below.
+splitter = RandomSplitter().set_weights(4.0, 6.0)
+
+# Uses the RandomSplitter to split the dataset; output[0] and output[1] are the two parts.
+output = splitter.transform(input_table)
+
+# Extracts and displays the results.
+print("Split Result 1 (40%)")
+for result in t_env.to_data_stream(output[0]).execute_and_collect():
+ print(str(result))
+
+print("Split Result 2 (60%)")
+for result in t_env.to_data_stream(output[1]).execute_and_collect():
+ print(str(result))
+
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RandomSplitterExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RandomSplitterExample.java
new file mode 100644
index 0000000..eb3ad82
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RandomSplitterExample.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates a RandomSplitter instance and uses it for data splitting. */
+public class RandomSplitterExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(1, 10, 0),
+                        Row.of(1, 10, 0),
+                        Row.of(1, 10, 0),
+                        Row.of(4, 10, 0),
+                        Row.of(5, 10, 0),
+                        Row.of(6, 10, 0),
+                        Row.of(7, 10, 0),
+                        Row.of(10, 10, 0),
+                        Row.of(13, 10, 3));
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates a RandomSplitter object and initializes its parameters.
+        RandomSplitter splitter = new RandomSplitter().setWeights(4.0, 6.0);
+
+        // Uses the RandomSplitter to split the input table.
+        Table[] splits = splitter.transform(inputTable);
+
+        // Extracts and displays the results, one split after the other.
+        printSplit("Split Result 1 (40%)", splits[0]);
+        printSplit("Split Result 2 (60%)", splits[1]);
+    }
+
+    /** Prints the given title followed by every row of the given split, one row per line. */
+    private static void printSplit(String title, Table split) {
+        System.out.println(title);
+        CloseableIterator<Row> rows = split.execute().collect();
+        while (rows.hasNext()) {
+            System.out.printf("%s\n", rows.next());
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java
new file mode 100644
index 0000000..a3c7b46
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * An AlgoOperator which splits a Table into N Tables according to the given weights.
+ *
+ * <p>Every input row is emitted to exactly one of the N output tables. The table is chosen by
+ * drawing a pseudo-random number in [0, 1) and locating it among the cumulative, normalized
+ * weights, so a larger weight makes its table proportionally more likely to receive a row. The
+ * output tables use the row type derived from the schema of the input table.
+ */
+public class RandomSplitter
+        implements AlgoOperator<RandomSplitter>, RandomSplitterParams<RandomSplitter> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public RandomSplitter() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Splits the single input table into one output table per configured weight.
+     *
+     * @param inputs Exactly one table to split.
+     * @return The split tables, in the same order as the weights.
+     * @throws IllegalArgumentException If the number of input tables is not one.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(
+                inputs.length == 1,
+                "RandomSplitter expects exactly one input table, but got %s.",
+                inputs.length);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        final Double[] weights = getWeights();
+        // The first split travels on the operator's main output, so only the remaining
+        // weights.length - 1 splits need a side-output tag. Java cannot create a generic array
+        // directly; the raw array below is only ever filled with OutputTag<Row> instances.
+        @SuppressWarnings("unchecked")
+        OutputTag<Row>[] outputTags = new OutputTag[weights.length - 1];
+        for (int i = 0; i < outputTags.length; ++i) {
+            outputTags[i] = new OutputTag<Row>("outputTag_" + i, outputTypeInfo) {};
+        }
+
+        SingleOutputStreamOperator<Row> results =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "SplitterOperator",
+                                outputTypeInfo,
+                                new SplitterOperator(outputTags, weights));
+
+        Table[] outputTables = new Table[weights.length];
+        // Split 0 is the main output; split i + 1 is the side output tagged with outputTags[i].
+        outputTables[0] = tEnv.fromDataStream(results);
+        for (int i = 0; i < outputTags.length; ++i) {
+            DataStream<Row> dataStream = results.getSideOutput(outputTags[i]);
+            outputTables[i + 1] = tEnv.fromDataStream(dataStream);
+        }
+        return outputTables;
+    }
+
+    /** Routes each incoming row either to the main output or to one of the side outputs. */
+    private static class SplitterOperator extends AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Row, Row> {
+        // Created with the fixed seed 0, so the pseudo-random sequence is reproducible.
+        private final Random random = new Random(0);
+        private final OutputTag<Row>[] outputTags;
+        // fractions[i] is the sum of weights[0..i] divided by the sum of all weights, i.e. the
+        // upper bound of the interval of [0, 1) that belongs to split i.
+        private final double[] fractions;
+
+        /**
+         * @param outputTags Side-output tags of the splits 1..N-1.
+         * @param weights The weights of all N splits.
+         */
+        public SplitterOperator(OutputTag<Row>[] outputTags, Double[] weights) {
+            this.outputTags = outputTags;
+            this.fractions = new double[weights.length];
+            double weightSum = 0.0;
+            for (Double weight : weights) {
+                weightSum += weight;
+            }
+            double currentSum = 0.0;
+            for (int i = 0; i < fractions.length; ++i) {
+                currentSum += weights[i];
+                fractions[i] = currentSum / weightSum;
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
+            int searchResult = Arrays.binarySearch(fractions, random.nextDouble());
+            // A miss is reported as -(insertionPoint) - 1, and the insertion point is the split
+            // number; an exact hit at position k also belongs to split k. Subtracting one more
+            // turns the split number into an index into outputTags, where -1 denotes split 0.
+            int index = searchResult < 0 ? -searchResult - 2 : searchResult - 1;
+            if (index == -1) {
+                output.collect(streamRecord);
+            } else {
+                output.collect(outputTags[index], streamRecord);
+            }
+        }
+    }
+
+    /** Saves the metadata of this operator to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    /**
+     * Loads a RandomSplitter from the given path.
+     *
+     * @param env Not used by this method.
+     * @param path The path that {@link #save(String)} wrote to.
+     */
+    public static RandomSplitter load(StreamTableEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitterParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitterParams.java
new file mode 100644
index 0000000..4095999
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitterParams.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Params of {@link RandomSplitter}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface RandomSplitterParams<T> extends WithParams<T> {
+    /**
+     * Weights should be an array with at least two elements, all of them finite and greater than
+     * zero. The weights will be normalized such that the sum of all elements equals to one.
+     */
+    Param<Double[]> WEIGHTS =
+            new DoubleArrayParam(
+                    "weights",
+                    "The weights of data splitting.",
+                    new Double[] {1.0, 1.0},
+                    weightsValidator());
+
+    default Double[] getWeights() {
+        return get(WEIGHTS);
+    }
+
+    default T setWeights(Double... value) {
+        return set(WEIGHTS, value);
+    }
+
+    /**
+     * Creates the validator of the weights parameter.
+     *
+     * @return A validator that accepts an array only if it has at least two elements and every
+     *     element is a non-null, finite number greater than zero.
+     */
+    static ParamValidator<Double[]> weightsValidator() {
+        return weights -> {
+            if (weights == null || weights.length < 2) {
+                return false;
+            }
+            for (Double weight : weights) {
+                // A null element would fail on unboxing, and NaN or infinite weights cannot be
+                // normalized into usable fractions, so all of them are rejected here.
+                if (weight == null || weight.isNaN() || weight.isInfinite() || weight <= 0.0) {
+                    return false;
+                }
+            }
+            return true;
+        };
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java
new file mode 100644
index 0000000..aeb0245
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStreamSource;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link RandomSplitter}. */
+public class RandomSplitterTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(1);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
+    /** Creates a single-column table holding the {@code size} values 0, 1, ..., size - 1. */
+    private Table getTable(int size) {
+        // fromSequence includes both bounds, hence the upper bound of size - 1.
+        DataStreamSource<Long> dataStream = env.fromSequence(0L, size - 1);
+        return tEnv.fromDataStream(dataStream);
+    }
+
+    @Test
+    public void testParam() {
+        RandomSplitter splitter = new RandomSplitter();
+        splitter.setWeights(0.3, 0.4);
+        assertArrayEquals(new Double[] {0.3, 0.4}, splitter.getWeights());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                        .as("test_input", "dummy_input");
+
+        RandomSplitter splitter = new RandomSplitter().setWeights(0.5, 0.1);
+        Table[] output = splitter.transform(tempTable);
+        assertEquals(2, output.length);
+        for (Table table : output) {
+            assertEquals(
+                    Arrays.asList("test_input", "dummy_input"),
+                    table.getResolvedSchema().getColumnNames());
+        }
+    }
+
+    @Test
+    public void testWeights() throws Exception {
+        Table data = getTable(1000);
+        RandomSplitter splitter = new RandomSplitter().setWeights(2.0, 1.0, 2.0);
+        Table[] output = splitter.transform(data);
+
+        List<Row> result0 = IteratorUtils.toList(tEnv.toDataStream(output[0]).executeAndCollect());
+        List<Row> result1 = IteratorUtils.toList(tEnv.toDataStream(output[1]).executeAndCollect());
+        List<Row> result2 = IteratorUtils.toList(tEnv.toDataStream(output[2]).executeAndCollect());
+        // Each split size may deviate by up to 10% from its share of the 1000 input rows.
+        assertEquals(1.0, result0.size() / 400.0, 0.1);
+        assertEquals(1.0, result1.size() / 200.0, 0.1);
+        assertEquals(1.0, result2.size() / 400.0, 0.1);
+        verifyResultTables(data, output);
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        Table data = getTable(2000);
+        RandomSplitter randomSplitter = new RandomSplitter().setWeights(4.0, 6.0);
+
+        RandomSplitter splitterLoad =
+                TestUtils.saveAndReload(
+                        tEnv, randomSplitter, tempFolder.newFolder().getAbsolutePath());
+
+        Table[] output = splitterLoad.transform(data);
+        List<Row> result0 = IteratorUtils.toList(tEnv.toDataStream(output[0]).executeAndCollect());
+        List<Row> result1 = IteratorUtils.toList(tEnv.toDataStream(output[1]).executeAndCollect());
+        // Each split size may deviate by up to 10% from its share of the 2000 input rows.
+        assertEquals(1.0, result0.size() / 800.0, 0.1);
+        assertEquals(1.0, result1.size() / 1200.0, 0.1);
+        verifyResultTables(data, output);
+    }
+
+    /** Checks that the output tables together contain exactly the rows of the input table. */
+    private void verifyResultTables(Table input, Table[] output) throws Exception {
+        List<Row> expectedData = IteratorUtils.toList(tEnv.toDataStream(input).executeAndCollect());
+        List<Row> results = new ArrayList<>();
+        for (Table table : output) {
+            List<Row> result = IteratorUtils.toList(tEnv.toDataStream(table).executeAndCollect());
+            results.addAll(result);
+        }
+        assertEquals(expectedData.size(), results.size());
+        compareResultCollections(
+                expectedData, results, Comparator.comparingLong(row -> row.getFieldAs(0)));
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/randomsplitter_example.py b/flink-ml-python/pyflink/examples/ml/feature/randomsplitter_example.py
new file mode 100644
index 0000000..9cdabc9
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/randomsplitter_example.py
@@ -0,0 +1,63 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a RandomSplitter instance and uses it for data
+# splitting.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.randomsplitter import RandomSplitter
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input table.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ (1, 10, 0),
+ (1, 10, 0),
+ (1, 10, 0),
+ (4, 10, 0),
+ (5, 10, 0),
+ (6, 10, 0),
+ (7, 10, 0),
+ (10, 10, 0),
+ (13, 10, 0)
+ ],
+ type_info=Types.ROW_NAMED(
+ ['f0', 'f1', "f2"],
+ [Types.INT(), Types.INT(), Types.INT()])))
+
+# Creates a RandomSplitter object and initializes its parameters.
+splitter = RandomSplitter().set_weights(4.0, 6.0)
+
+# Uses the RandomSplitter to split the dataset.
+output = splitter.transform(input_table)
+
+# Extracts and displays the results.
+print("Split Result 1 (40%)")
+for result in t_env.to_data_stream(output[0]).execute_and_collect():
+ print(str(result))
+
+print("Split Result 2 (60%)")
+for result in t_env.to_data_stream(output[1]).execute_and_collect():
+ print(str(result))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/normalizer.py b/flink-ml-python/pyflink/ml/lib/feature/normalizer.py
index 53c4e6b..2b99ee2 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/normalizer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/normalizer.py
@@ -44,7 +44,7 @@ class _NormalizerParams(
def set_p(self, value: float):
return typing.cast(_NormalizerParams, self.set(self.P, value))
- def get_p(self) -> bool:
+ def get_p(self) -> float:
return self.get(self.P)
@property
diff --git a/flink-ml-python/pyflink/ml/lib/feature/randomsplitter.py b/flink-ml-python/pyflink/ml/lib/feature/randomsplitter.py
new file mode 100644
index 0000000..4c23891
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/randomsplitter.py
@@ -0,0 +1,80 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+from typing import Tuple
+from pyflink.ml.core.param import Param, FloatArrayParam, ParamValidator
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+
+
+class _RandomSplitterParams(
+ JavaWithParams
+):
+ """
+ Checks the weights parameter.
+ """
+ def weights_validator(self) -> ParamValidator[Tuple[float]]:
+ class WeightsValidator(ParamValidator[Tuple[float]]):
+ def validate(self, weights: Tuple[float]) -> bool:
+ for val in weights:
+ if val <= 0.0:
+ return False
+ return len(weights) > 1
+ return WeightsValidator()
+
+ """
+ Params for :class:`RandomSplitter`.
+ Weights should be a non-empty array with all elements greater than zero.
+ The weights will be normalized such that the sum of all elements equals
+ to one.
+ """
+ WEIGHTS: Param[Tuple[float]] = FloatArrayParam(
+ "weights",
+ "The weights of data splitting.",
+ [1.0, 1.0],
+ weights_validator(None))
+
+ def __init__(self, java_params):
+ super(_RandomSplitterParams, self).__init__(java_params)
+
+ def set_weights(self, *value: float):
+ return typing.cast(_RandomSplitterParams, self.set(self.WEIGHTS, value))
+
+ def get_weights(self) -> Tuple[float, ...]:
+ return self.get(self.WEIGHTS)
+
+ @property
+ def weights(self):
+ return self.get_weights()
+
+
+class RandomSplitter(JavaFeatureTransformer, _RandomSplitterParams):
+ """
+ An AlgoOperator which splits a table into N tables according to the given weights.
+ """
+
+ def __init__(self, java_model=None):
+ super(RandomSplitter, self).__init__(java_model)
+
+ @classmethod
+ def _java_transformer_package_name(cls) -> str:
+ return "randomsplitter"
+
+ @classmethod
+ def _java_transformer_class_name(cls) -> str:
+ return "RandomSplitter"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_normalizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_normalizer.py
index f765a2f..f8d07d7 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_normalizer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_normalizer.py
@@ -64,6 +64,7 @@ class NormalizerTest(PyFlinkMLTestCase):
self.assertEqual("intput_vec", normalizer.get_input_col())
self.assertEqual(1.5, normalizer.get_p())
+ self.assertEqual(float, type(normalizer.get_p()))
self.assertEqual('output_vec', normalizer.get_output_col())
def test_save_load_transform(self):
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_randomsplitter.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_randomsplitter.py
new file mode 100644
index 0000000..bcf99c7
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_randomsplitter.py
@@ -0,0 +1,77 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.lib.feature.randomsplitter import RandomSplitter
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class RandomSplitterTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(RandomSplitterTest, self).setUp()
+ data = []
+ for i in range(1, 10000):
+ data.append((i, ))
+ self.input_table = self.t_env.from_data_stream(
+ self.env.from_collection(
+ data,
+ type_info=Types.ROW_NAMED(
+ ['f0', ],
+ [Types.INT(), ])))
+
+ def test_param(self):
+ splitter = RandomSplitter()
+ splitter.set_weights(0.2, 0.8)
+ self.assertEqual(0.2, splitter.weights[0])
+ self.assertEqual(0.8, splitter.weights[1])
+
+ def test_output_schema(self):
+ splitter = RandomSplitter()
+ input_data_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ ('', ''),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['test_input', 'dummy_input'],
+ [Types.STRING(), Types.STRING()])))
+ output = splitter.set_weights(0.5, 0.5) \
+ .transform(input_data_table)[0]
+
+ self.assertEqual(
+ ['test_input', 'dummy_input'],
+ output.get_schema().get_field_names())
+
+ def test_transform(self):
+ splitter = RandomSplitter().set_weights(0.4, 0.6)
+
+ output = splitter.transform(self.input_table)
+ results = [result for result in self.t_env.to_data_stream(output[0]).execute_and_collect()]
+ self.assertAlmostEqual(len(results) / 4000.0, 1.0, delta=0.1)
+
+ def test_save_load_transform(self):
+ splitter = RandomSplitter().set_weights(0.4, 0.6)
+ path = os.path.join(self.temp_dir, 'test_save_load_random_splitter')
+ splitter.save(path)
+ splitter = RandomSplitter.load(self.t_env, path)
+
+ output = splitter.transform(self.input_table)
+ results = [result for result in self.t_env.to_data_stream(output[0]).execute_and_collect()]
+ self.assertAlmostEqual(len(results) / 4000.0, 1.0, delta=0.1)