Posted to commits@flink.apache.org by zh...@apache.org on 2022/11/07 08:03:36 UTC

[flink-ml] branch master updated: [FLINK-29434] Add AlgoOperator for RandomSplitter

This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 0aad2f9  [FLINK-29434] Add AlgoOperator for RandomSplitter
0aad2f9 is described below

commit 0aad2f942887e0513093106faf9680bb9de8913e
Author: weibo <wb...@pku.edu.cn>
AuthorDate: Mon Nov 7 16:03:30 2022 +0800

    [FLINK-29434] Add AlgoOperator for RandomSplitter
    
    This closes #160.
---
 .../docs/operators/feature/randomsplitter.md       | 147 +++++++++++++++++++++
 .../ml/examples/feature/RandomSplitterExample.java |  65 +++++++++
 .../ml/feature/randomsplitter/RandomSplitter.java  | 129 ++++++++++++++++++
 .../randomsplitter/RandomSplitterParams.java       |  65 +++++++++
 .../flink/ml/feature/RandomSplitterTest.java       | 137 +++++++++++++++++++
 .../examples/ml/feature/randomsplitter_example.py  |  63 +++++++++
 .../pyflink/ml/lib/feature/normalizer.py           |   2 +-
 .../pyflink/ml/lib/feature/randomsplitter.py       |  80 +++++++++++
 .../ml/lib/feature/tests/test_normalizer.py        |   1 +
 .../ml/lib/feature/tests/test_randomsplitter.py    |  77 +++++++++++
 10 files changed, 765 insertions(+), 1 deletion(-)

diff --git a/docs/content/docs/operators/feature/randomsplitter.md b/docs/content/docs/operators/feature/randomsplitter.md
new file mode 100644
index 0000000..3fb2c22
--- /dev/null
+++ b/docs/content/docs/operators/feature/randomsplitter.md
@@ -0,0 +1,147 @@
+---
+title: "RandomSplitter"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/randomsplitter.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## RandomSplitter
+
+An AlgoOperator which splits a table into N tables according to the given weights.
+
+### Parameters
+
+| Key     | Default      | Type     | Required | Description                    |
+|:--------|:-------------|:---------|:---------|:-------------------------------|
+| weights | `[1.0, 1.0]` | Double[] | no       | The weights of data splitting. |
+
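+The weights are relative: they are normalized so that their sum equals one. For example, weights `[4.0, 6.0]` split the input into two tables containing roughly 40% and 60% of the rows.
+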
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates a RandomSplitter instance and uses it for data splitting. */
+public class RandomSplitterExample {
+	public static void main(String[] args) {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+		// Generates input data.
+		DataStream<Row> inputStream =
+			env.fromElements(
+				Row.of(1, 10, 0),
+				Row.of(1, 10, 0),
+				Row.of(1, 10, 0),
+				Row.of(4, 10, 0),
+				Row.of(5, 10, 0),
+				Row.of(6, 10, 0),
+				Row.of(7, 10, 0),
+				Row.of(10, 10, 0),
+				Row.of(13, 10, 3));
+		Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+		// Creates a RandomSplitter object and initializes its parameters.
+		RandomSplitter splitter = new RandomSplitter().setWeights(4.0, 6.0);
+
+		// Uses the RandomSplitter to split inputData.
+		Table[] outputTables = splitter.transform(inputTable);
+
+		// Extracts and displays the results.
+		System.out.println("Split Result 1 (40%)");
+		for (CloseableIterator<Row> it = outputTables[0].execute().collect(); it.hasNext(); ) {
+			System.out.printf("%s\n", it.next());
+		}
+		System.out.println("Split Result 2 (60%)");
+		for (CloseableIterator<Row> it = outputTables[1].execute().collect(); it.hasNext(); ) {
+			System.out.printf("%s\n", it.next());
+		}
+	}
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+# Simple program that creates a RandomSplitter instance and uses it for data splitting.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.randomsplitter import RandomSplitter
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input table.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (1, 10, 0),
+        (1, 10, 0),
+        (1, 10, 0),
+        (4, 10, 0),
+        (5, 10, 0),
+        (6, 10, 0),
+        (7, 10, 0),
+        (10, 10, 0),
+        (13, 10, 0)
+    ],
+        type_info=Types.ROW_NAMED(
+            ['f0', 'f1', 'f2'],
+            [Types.INT(), Types.INT(), Types.INT()])))
+
+# Creates a RandomSplitter object and initializes its parameters.
+splitter = RandomSplitter().set_weights(4.0, 6.0)
+
+# Uses the RandomSplitter to split the dataset.
+output = splitter.transform(input_table)
+
+# Extracts and displays the results.
+print("Split Result 1 (40%)")
+for result in t_env.to_data_stream(output[0]).execute_and_collect():
+    print(str(result))
+
+print("Split Result 2 (60%)")
+for result in t_env.to_data_stream(output[1]).execute_and_collect():
+    print(str(result))
+
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RandomSplitterExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RandomSplitterExample.java
new file mode 100644
index 0000000..eb3ad82
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RandomSplitterExample.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates a RandomSplitter instance and uses it for data splitting. */
+public class RandomSplitterExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(1, 10, 0),
+                        Row.of(1, 10, 0),
+                        Row.of(1, 10, 0),
+                        Row.of(4, 10, 0),
+                        Row.of(5, 10, 0),
+                        Row.of(6, 10, 0),
+                        Row.of(7, 10, 0),
+                        Row.of(10, 10, 0),
+                        Row.of(13, 10, 3));
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates a RandomSplitter object and initializes its parameters.
+        RandomSplitter splitter = new RandomSplitter().setWeights(4.0, 6.0);
+
+        // Uses the RandomSplitter to split inputData.
+        Table[] outputTable = splitter.transform(inputTable);
+
+        // Extracts and displays the results.
+        System.out.println("Split Result 1 (40%)");
+        for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) {
+            System.out.printf("%s\n", it.next());
+        }
+        System.out.println("Split Result 2 (60%)");
+        for (CloseableIterator<Row> it = outputTable[1].execute().collect(); it.hasNext(); ) {
+            System.out.printf("%s\n", it.next());
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java
new file mode 100644
index 0000000..a3c7b46
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+/** An AlgoOperator which splits a Table into N Tables according to the given weights. */
+public class RandomSplitter
+        implements AlgoOperator<RandomSplitter>, RandomSplitterParams<RandomSplitter> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public RandomSplitter() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        final Double[] weights = getWeights();
+        OutputTag<Row>[] outputTags = new OutputTag[weights.length - 1];
+        for (int i = 0; i < outputTags.length; ++i) {
+            outputTags[i] = new OutputTag<Row>("outputTag_" + i, outputTypeInfo) {};
+        }
+
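+        // The main output of the operator carries the first split; each side output carries one of
+        // the remaining splits.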
+        SingleOutputStreamOperator<Row> results =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "SplitterOperator",
+                                outputTypeInfo,
+                                new SplitterOperator(outputTags, weights));
+
+        Table[] outputTables = new Table[weights.length];
+        outputTables[0] = tEnv.fromDataStream(results);
+
+        for (int i = 0; i < outputTags.length; ++i) {
+            DataStream<Row> dataStream = results.getSideOutput(outputTags[i]);
+            outputTables[i + 1] = tEnv.fromDataStream(dataStream);
+        }
+        return outputTables;
+    }
+
+    private static class SplitterOperator extends AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Row, Row> {
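+        // Uses a fixed seed so that the split is reproducible for the same input and parallelism.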
+        private final Random random = new Random(0);
+        OutputTag<Row>[] outputTag;
+        final double[] fractions;
+
+        public SplitterOperator(OutputTag<Row>[] outputTag, Double[] weights) {
+            this.outputTag = outputTag;
+            this.fractions = new double[weights.length];
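+            // fractions[i] holds the normalized cumulative sum of weights[0..i], so the last entry is 1.0.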
+            double weightSum = 0.0;
+            for (Double weight : weights) {
+                weightSum += weight;
+            }
+            double currentSum = 0.0;
+            for (int i = 0; i < fractions.length; ++i) {
+                currentSum += weights[i];
+                fractions[i] = currentSum / weightSum;
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
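+            // Draw a uniform random number and locate the cumulative-weight interval it falls into:
+            // the first interval maps to the main output, each later interval to the corresponding side output.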
+            int searchResult = Arrays.binarySearch(fractions, random.nextDouble());
+            int index = searchResult < 0 ? -searchResult - 2 : searchResult - 1;
+            if (index == -1) {
+                output.collect(streamRecord);
+            } else {
+                output.collect(outputTag[index], streamRecord);
+            }
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static RandomSplitter load(StreamTableEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitterParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitterParams.java
new file mode 100644
index 0000000..4095999
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitterParams.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Params of {@link RandomSplitter}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface RandomSplitterParams<T> extends WithParams<T> {
+    /**
+     * Weights should be a non-empty array with all elements greater than zero. The weights will be
+     * normalized so that the sum of all elements equals one.
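+     * For example, weights {4.0, 6.0} split the input data into two tables containing roughly 40%
+     * and 60% of the rows.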
+     */
+    Param<Double[]> WEIGHTS =
+            new DoubleArrayParam(
+                    "weights",
+                    "The weights of data splitting.",
+                    new Double[] {1.0, 1.0},
+                    weightsValidator());
+
+    default Double[] getWeights() {
+        return get(WEIGHTS);
+    }
+
+    default T setWeights(Double... value) {
+        return set(WEIGHTS, value);
+    }
+
+    // Checks the weights parameter.
+    static ParamValidator<Double[]> weightsValidator() {
+        return weights -> {
+            if (weights == null) {
+                return false;
+            }
+            for (Double weight : weights) {
+                if (weight <= 0.0) {
+                    return false;
+                }
+            }
+            return weights.length > 1;
+        };
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java
new file mode 100644
index 0000000..aeb0245
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStreamSource;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link RandomSplitter}. */
+public class RandomSplitterTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(1);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
+    private Table getTable(int size) {
+        DataStreamSource<Long> dataStream = env.fromSequence(0L, size);
+        return tEnv.fromDataStream(dataStream);
+    }
+
+    @Test
+    public void testParam() {
+        RandomSplitter splitter = new RandomSplitter();
+        splitter.setWeights(0.3, 0.4);
+        assertArrayEquals(new Double[] {0.3, 0.4}, splitter.getWeights());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                        .as("test_input", "dummy_input");
+
+        RandomSplitter splitter = new RandomSplitter().setWeights(0.5, 0.1);
+        Table[] output = splitter.transform(tempTable);
+        assertEquals(2, output.length);
+        for (Table table : output) {
+            assertEquals(
+                    Arrays.asList("test_input", "dummy_input"),
+                    table.getResolvedSchema().getColumnNames());
+        }
+    }
+
+    @Test
+    public void testWeights() throws Exception {
+        Table data = getTable(1000);
+        RandomSplitter splitter = new RandomSplitter().setWeights(2.0, 1.0, 2.0);
+        Table[] output = splitter.transform(data);
+
+        List<Row> result0 = IteratorUtils.toList(tEnv.toDataStream(output[0]).executeAndCollect());
+        List<Row> result1 = IteratorUtils.toList(tEnv.toDataStream(output[1]).executeAndCollect());
+        List<Row> result2 = IteratorUtils.toList(tEnv.toDataStream(output[2]).executeAndCollect());
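+        // With weights 2:1:2 over about 1000 rows, the splits should contain roughly 400, 200 and 400 rows.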
+        assertEquals(result0.size() / 400.0, 1.0, 0.1);
+        assertEquals(result1.size() / 200.0, 1.0, 0.1);
+        assertEquals(result2.size() / 400.0, 1.0, 0.1);
+        verifyResultTables(data, output);
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        Table data = getTable(2000);
+        RandomSplitter randomSplitter = new RandomSplitter().setWeights(4.0, 6.0);
+
+        RandomSplitter splitterLoad =
+                TestUtils.saveAndReload(
+                        tEnv, randomSplitter, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
+        Table[] output = splitterLoad.transform(data);
+        List<Row> result0 = IteratorUtils.toList(tEnv.toDataStream(output[0]).executeAndCollect());
+        List<Row> result1 = IteratorUtils.toList(tEnv.toDataStream(output[1]).executeAndCollect());
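+        // With weights 4:6 over about 2000 rows, the splits should contain roughly 800 and 1200 rows.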
+        assertEquals(result0.size() / 800.0, 1.0, 0.1);
+        assertEquals(result1.size() / 1200.0, 1.0, 0.1);
+        verifyResultTables(data, output);
+    }
+
+    private void verifyResultTables(Table input, Table[] output) throws Exception {
+        List<Row> expectedData = IteratorUtils.toList(tEnv.toDataStream(input).executeAndCollect());
+        List<Row> results = new ArrayList<>();
+        for (Table table : output) {
+            List<Row> result = IteratorUtils.toList(tEnv.toDataStream(table).executeAndCollect());
+            results.addAll(result);
+        }
+        assertEquals(expectedData.size(), results.size());
+        compareResultCollections(
+                expectedData, results, Comparator.comparingLong(row -> row.getFieldAs(0)));
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/randomsplitter_example.py b/flink-ml-python/pyflink/examples/ml/feature/randomsplitter_example.py
new file mode 100644
index 0000000..9cdabc9
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/randomsplitter_example.py
@@ -0,0 +1,63 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a RandomSplitter instance and uses it for data splitting.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.randomsplitter import RandomSplitter
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input table.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (1, 10, 0),
+        (1, 10, 0),
+        (1, 10, 0),
+        (4, 10, 0),
+        (5, 10, 0),
+        (6, 10, 0),
+        (7, 10, 0),
+        (10, 10, 0),
+        (13, 10, 0)
+    ],
+        type_info=Types.ROW_NAMED(
+            ['f0', 'f1', 'f2'],
+            [Types.INT(), Types.INT(), Types.INT()])))
+
+# Creates a RandomSplitter object and initializes its parameters.
+splitter = RandomSplitter().set_weights(4.0, 6.0)
+
+# Uses the RandomSplitter to split the dataset.
+output = splitter.transform(input_table)
+
+# Extracts and displays the results.
+print("Split Result 1 (40%)")
+for result in t_env.to_data_stream(output[0]).execute_and_collect():
+    print(str(result))
+
+print("Split Result 2 (60%)")
+for result in t_env.to_data_stream(output[1]).execute_and_collect():
+    print(str(result))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/normalizer.py b/flink-ml-python/pyflink/ml/lib/feature/normalizer.py
index 53c4e6b..2b99ee2 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/normalizer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/normalizer.py
@@ -44,7 +44,7 @@ class _NormalizerParams(
     def set_p(self, value: float):
         return typing.cast(_NormalizerParams, self.set(self.P, value))
 
-    def get_p(self) -> bool:
+    def get_p(self) -> float:
         return self.get(self.P)
 
     @property
diff --git a/flink-ml-python/pyflink/ml/lib/feature/randomsplitter.py b/flink-ml-python/pyflink/ml/lib/feature/randomsplitter.py
new file mode 100644
index 0000000..4c23891
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/randomsplitter.py
@@ -0,0 +1,80 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+from typing import Tuple
+from pyflink.ml.core.param import Param, FloatArrayParam, ParamValidator
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+
+
+class _RandomSplitterParams(
+    JavaWithParams
+):
+    """
+    Params for :class:`RandomSplitter`.
+
+    Weights should be a non-empty array with all elements greater than zero.
+    The weights will be normalized so that the sum of all elements equals one.
+    """
+    def weights_validator(self) -> ParamValidator[Tuple[float]]:
+        """Checks the weights parameter."""
+        class WeightsValidator(ParamValidator[Tuple[float]]):
+            def validate(self, weights: Tuple[float]) -> bool:
+                for val in weights:
+                    if val <= 0.0:
+                        return False
+                return len(weights) > 1
+        return WeightsValidator()
+
+    WEIGHTS: Param[Tuple[float]] = FloatArrayParam(
+        "weights",
+        "The weights of data splitting.",
+        [1.0, 1.0],
+        weights_validator(None))
+
+    def __init__(self, java_params):
+        super(_RandomSplitterParams, self).__init__(java_params)
+
+    def set_weights(self, *value: float):
+        return typing.cast(_RandomSplitterParams, self.set(self.WEIGHTS, value))
+
+    def get_weights(self) -> Tuple[float, ...]:
+        return self.get(self.WEIGHTS)
+
+    @property
+    def weights(self):
+        return self.get_weights()
+
+
+class RandomSplitter(JavaFeatureTransformer, _RandomSplitterParams):
+    """
+    An AlgoOperator which splits a table into N tables according to the given weights.
+    """
+
+    def __init__(self, java_model=None):
+        super(RandomSplitter, self).__init__(java_model)
+
+    @classmethod
+    def _java_transformer_package_name(cls) -> str:
+        return "randomsplitter"
+
+    @classmethod
+    def _java_transformer_class_name(cls) -> str:
+        return "RandomSplitter"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_normalizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_normalizer.py
index f765a2f..f8d07d7 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_normalizer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_normalizer.py
@@ -64,6 +64,7 @@ class NormalizerTest(PyFlinkMLTestCase):
 
         self.assertEqual("intput_vec", normalizer.get_input_col())
         self.assertEqual(1.5, normalizer.get_p())
+        self.assertEqual(float, type(normalizer.get_p()))
         self.assertEqual('output_vec', normalizer.get_output_col())
 
     def test_save_load_transform(self):
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_randomsplitter.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_randomsplitter.py
new file mode 100644
index 0000000..bcf99c7
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_randomsplitter.py
@@ -0,0 +1,77 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.lib.feature.randomsplitter import RandomSplitter
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class RandomSplitterTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(RandomSplitterTest, self).setUp()
+        data = []
+        for i in range(1, 10000):
+            data.append((i, ))
+        self.input_table = self.t_env.from_data_stream(
+            self.env.from_collection(
+                data,
+                type_info=Types.ROW_NAMED(
+                    ['f0', ],
+                    [Types.INT(), ])))
+
+    def test_param(self):
+        splitter = RandomSplitter()
+        splitter.set_weights(0.2, 0.8)
+        self.assertEqual(0.2, splitter.weights[0])
+        self.assertEqual(0.8, splitter.weights[1])
+
+    def test_output_schema(self):
+        splitter = RandomSplitter()
+        input_data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                ('', ''),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['test_input', 'dummy_input'],
+                    [Types.STRING(), Types.STRING()])))
+        output = splitter.set_weights(0.5, 0.5) \
+            .transform(input_data_table)[0]
+
+        self.assertEqual(
+            ['test_input', 'dummy_input'],
+            output.get_schema().get_field_names())
+
+    def test_transform(self):
+        splitter = RandomSplitter().set_weights(0.4, 0.6)
+
+        output = splitter.transform(self.input_table)
+        results = [result for result in self.t_env.to_data_stream(output[0]).execute_and_collect()]
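+        # With weights 0.4/0.6 over about 10000 rows, the first split should contain roughly 4000 rows.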
+        self.assertAlmostEqual(len(results) / 4000.0, 1.0, delta=0.1)
+
+    def test_save_load_transform(self):
+        splitter = RandomSplitter().set_weights(0.4, 0.6)
+        path = os.path.join(self.temp_dir, 'test_save_load_random_splitter')
+        splitter.save(path)
+        splitter = RandomSplitter.load(self.t_env, path)
+
+        output = splitter.transform(self.input_table)
+        results = [result for result in self.t_env.to_data_stream(output[0]).execute_and_collect()]
+        self.assertAlmostEqual(len(results) / 4000.0, 1.0, delta=0.1)