You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by he...@apache.org on 2020/08/18 01:33:08 UTC

[flink] branch master updated: [FLINK-18947][python] Support partition_custom() for Python DataStream API. (#13155)

This is an automated email from the ASF dual-hosted git repository.

hequn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 917c3d2  [FLINK-18947][python] Support partition_custom() for Python DataStream API. (#13155)
917c3d2 is described below

commit 917c3d2f294f17ea8031dce2220f142c67520463
Author: Shuiqiang Chen <ac...@alibaba-inc.com>
AuthorDate: Tue Aug 18 09:31:26 2020 +0800

    [FLINK-18947][python] Support partition_custom() for Python DataStream API. (#13155)
---
 flink-python/pom.xml                               |  1 +
 flink-python/pyflink/datastream/data_stream.py     | 89 +++++++++++++++++++++-
 flink-python/pyflink/datastream/functions.py       | 44 ++++++++++-
 .../pyflink/datastream/tests/test_data_stream.py   | 24 ++++++
 .../python/PartitionCustomKeySelector.java         | 36 +++++++++
 ...treamPythonPartitionCustomFunctionOperator.java | 75 ++++++++++++++++++
 .../env/beam/ProcessPythonEnvironmentManager.java  |  4 +
 .../apache/flink/python/util/PythonConfigUtil.java | 26 ++++++-
 .../python/AbstractPythonFunctionOperatorBase.java |  3 +-
 .../util/PartitionCustomTestMapFunction.java       | 48 ++++++++++++
 10 files changed, 343 insertions(+), 7 deletions(-)

diff --git a/flink-python/pom.xml b/flink-python/pom.xml
index 3246761..3b5f5aa 100644
--- a/flink-python/pom.xml
+++ b/flink-python/pom.xml
@@ -460,6 +460,7 @@ under the License.
 							<includes>
 								<include>**/DataStreamTestCollectSink.class</include>
 								<include>**/MyCustomSourceFunction.class</include>
+								<include>**/PartitionCustomTestMapFunction.class</include>
 							</includes>
 							<outputDirectory>
 								${project.build.directory}/data-stream-test
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index e227a51..2c5f526 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -15,7 +15,7 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
-
+import os
 from typing import Callable, Union
 
 from pyflink.common import typeinfo, ExecutionConfig
@@ -24,7 +24,8 @@ from pyflink.common.typeinfo import TypeInformation
 from pyflink.datastream.functions import _get_python_env, FlatMapFunctionWrapper, FlatMapFunction, \
     MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, FilterFunction, \
     FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector, ReduceFunction, \
-    ReduceFunctionWrapper, CoMapFunction, CoFlatMapFunction
+    ReduceFunctionWrapper, CoMapFunction, CoFlatMapFunction, Partitioner, \
+    PartitionerFunctionWrapper
 from pyflink.java_gateway import get_gateway
 
 
@@ -445,6 +446,77 @@ class DataStream(object):
         """
         return DataStream(self._j_data_stream.broadcast())
 
+    def partition_custom(self, partitioner: Union[Callable, Partitioner],
+                         key_selector: Union[Callable, KeySelector]) -> 'DataStream':
+        """
+        Partitions a DataStream on the key returned by the selector, using a custom partitioner.
+        This method takes the key selector to get the key to partition on, and a partitioner that
+        accepts the key type.
+
+        Note that this method works only on single field keys, i.e. the selector cannet return
+        tuples of fields.
+
+        :param partitioner: The partitioner to assign partitions to keys.
+        :param key_selector: The KeySelector with which the DataStream is partitioned.
+        :return: The partitioned DataStream.
+        """
+        if callable(key_selector):
+            key_selector = KeySelectorFunctionWrapper(key_selector)
+        if not isinstance(key_selector, (KeySelector, KeySelectorFunctionWrapper)):
+            raise TypeError("Parameter key_selector should be a type of KeySelector.")
+
+        if callable(partitioner):
+            partitioner = PartitionerFunctionWrapper(partitioner)
+        if not isinstance(partitioner, (Partitioner, PartitionerFunctionWrapper)):
+            raise TypeError("Parameter partitioner should be a type of Partitioner.")
+
+        gateway = get_gateway()
+        data_stream_num_partitions_env_key = gateway.jvm\
+            .org.apache.flink.datastream.runtime.operators.python\
+            .DataStreamPythonPartitionCustomFunctionOperator.DATA_STREAM_NUM_PARTITIONS
+
+        class PartitionCustomMapFunction(MapFunction):
+            """
+            A wrapper class for partition_custom map function. It indicates that it is a partition
+            custom operation that we need to apply DataStreamPythonPartitionCustomFunctionOperator
+            to run the map function.
+            """
+
+            def __init__(self):
+                self.num_partitions = None
+
+            def map(self, value):
+                return self.partition_custom_map(value)
+
+            def partition_custom_map(self, value):
+                if self.num_partitions is None:
+                    self.num_partitions = int(os.environ[data_stream_num_partitions_env_key])
+                partition = partitioner.partition(key_selector.get_key(value), self.num_partitions)
+                return partition, value
+
+            def __repr__(self) -> str:
+                return '_Flink_PartitionCustomMapFunction'
+
+        original_type_info = self.get_type()
+        intermediate_map_stream = self.map(PartitionCustomMapFunction(),
+                                           type_info=Types.ROW([Types.INT(), original_type_info]))
+        intermediate_map_stream.name(
+            gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
+            .STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)
+
+        JPartitionCustomKeySelector = gateway.jvm\
+            .org.apache.flink.datastream.runtime.functions.python.PartitionCustomKeySelector
+        JIdParitioner = gateway.jvm\
+            .org.apache.flink.api.java.functions.IdPartitioner
+        intermediate_map_stream = DataStream(intermediate_map_stream._j_data_stream
+                                             .partitionCustom(JIdParitioner(),
+                                                              JPartitionCustomKeySelector()))
+
+        values_map_stream = intermediate_map_stream.map(lambda x: x[1], original_type_info)
+        values_map_stream.name(gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
+                               .KEYED_STREAM_VALUE_OPERATOR_NAME)
+        return DataStream(values_map_stream._j_data_stream)
+
     def _get_java_python_function_operator(self, func: Union[Function, FunctionWrapper],
                                            type_info: TypeInformation, func_name: str,
                                            func_type: int):
@@ -506,8 +578,13 @@ class DataStream(object):
                 j_python_data_stream_function_info)
             return j_python_data_stream_function_operator, j_output_type_info
         else:
-            DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
-                .operators.python.DataStreamPythonStatelessFunctionOperator
+            if str(func) == '_Flink_PartitionCustomMapFunction':
+                DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
+                    .operators.python.DataStreamPythonPartitionCustomFunctionOperator
+            else:
+                DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
+                    .operators.python.DataStreamPythonStatelessFunctionOperator
+
             j_python_data_stream_function_operator = DataStreamPythonFunctionOperator(
                 j_conf,
                 j_input_types,
@@ -729,6 +806,10 @@ class KeyedStream(DataStream):
     def broadcast(self) -> 'DataStream':
         raise Exception('Cannot override partitioning for KeyedStream.')
 
+    def partition_custom(self, partitioner: Union[Callable, Partitioner],
+                         key_selector: Union[Callable, KeySelector]) -> 'DataStream':
+        raise Exception('Cannot override partitioning for KeyedStream.')  # same as broadcast()
+
     def print(self, sink_identifier=None):
         return self._values().print()
 
diff --git a/flink-python/pyflink/datastream/functions.py b/flink-python/pyflink/datastream/functions.py
index 76f5025..3f576d4 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -17,7 +17,7 @@
 ################################################################################
 
 import abc
-from typing import Union
+from typing import Union, Any
 
 from py4j.java_gateway import JavaObject
 
@@ -240,6 +240,23 @@ class FilterFunction(Function):
         pass
 
 
+class Partitioner(Function):
+    """
+    Base class for user-defined functions that assign a partition index to each key.
+    """
+
+    @abc.abstractmethod
+    def partition(self, key: Any, num_partitions: int) -> int:
+        """
+        Computes the partition index for the given key.
+
+        :param key: The key returned by the key selector for a record.
+        :param num_partitions: The number of partitions (the downstream operator's parallelism).
+        :return: The partition index; expected to lie in [0, num_partitions).
+        """
+        pass
+
+
 class FunctionWrapper(object):
     """
     A basic wrapper class for user defined function.
@@ -360,6 +377,31 @@ class KeySelectorFunctionWrapper(FunctionWrapper):
         return self._func(value)
 
 
+class PartitionerFunctionWrapper(FunctionWrapper):
+    """
+    A wrapper class for Partitioner. It's used for wrapping up user defined function in a
+    Partitioner when user does not implement a Partitioner but directly pass a function
+    object or a lambda function to partition_custom() function.
+    """
+    def __init__(self, func):
+        """
+        The constructor of PartitionerFunctionWrapper.
+
+        :param func: user defined function object.
+        """
+        super(PartitionerFunctionWrapper, self).__init__(func)
+
+    def partition(self, key: Any, num_partitions: int) -> int:
+        """
+        A delegated partition function to invoke user defined function.
+
+        :param key: The key.
+        :param num_partitions: The number of partitions to partition into.
+        :return: The partition index.
+        """
+        return self._func(key, num_partitions)
+
+
 def _get_python_env():
     """
     An util function to get a python user defined function execution environment.
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 2e352a7..655e847 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -19,6 +19,7 @@ import decimal
 
 from pyflink.common.typeinfo import Types
 from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.datastream.data_stream import DataStream
 from pyflink.datastream.functions import FilterFunction
 from pyflink.datastream.functions import KeySelector
 from pyflink.datastream.functions import MapFunction, FlatMapFunction
@@ -404,6 +405,29 @@ class DataStreamTests(PyFlinkTestCase):
         pre_ship_strategy = shuffle_node['predecessors'][0]['ship_strategy']
         self.assertEqual(pre_ship_strategy, 'SHUFFLE')
 
+    def test_partition_custom(self):
+        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2),
+                                       ('f', 7), ('g', 7), ('h', 8), ('i', 8), ('j', 9)],
+                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
+
+        expected_num_partitions = 5
+
+        def my_partitioner(key, num_partitions):
+            assert expected_num_partitions, num_partitions
+            return key % num_partitions
+
+        partitioned_stream = ds.map(lambda x: x, type_info=Types.ROW([Types.STRING(),
+                                                                      Types.INT()]))\
+            .set_parallelism(4).partition_custom(my_partitioner, lambda x: x[1])
+
+        JPartitionCustomTestMapFunction = get_gateway().jvm\
+            .org.apache.flink.python.util.PartitionCustomTestMapFunction
+        test_map_stream = DataStream(partitioned_stream
+                                     ._j_data_stream.map(JPartitionCustomTestMapFunction()))
+        test_map_stream.set_parallelism(expected_num_partitions).add_sink(self.test_sink)
+
+        self.env.execute('test_partition_custom')
+
     def test_keyed_stream_partitioning(self):
         ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)])
         keyed_stream = ds.key_by(lambda x: x[1])
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/functions/python/PartitionCustomKeySelector.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/functions/python/PartitionCustomKeySelector.java
new file mode 100644
index 0000000..04391b3
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/functions/python/PartitionCustomKeySelector.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.runtime.functions.python;
+
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.types.Row;
+
+/**
+ * Returns the first field of the input row, which holds the partition index computed by the user-defined partitioner
+ * in {@link org.apache.flink.datastream.runtime.operators.python.DataStreamPythonPartitionCustomFunctionOperator}.
+ */
+public class PartitionCustomKeySelector implements KeySelector<Row, Integer> {
+
+	private static final long serialVersionUID = 1L;
+
+	@Override
+	public Integer getKey(Row value) throws Exception {
+		return (Integer) value.getField(0);
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonPartitionCustomFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonPartitionCustomFunctionOperator.java
new file mode 100644
index 0000000..2a39935
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonPartitionCustomFunctionOperator.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.runtime.operators.python;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.CoreOptions;
+import org.apache.flink.datastream.runtime.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.datastream.runtime.runners.python.beam.BeamDataStreamPythonStatelessFunctionRunner;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.python.env.beam.ProcessPythonEnvironmentManager;
+
+/**
+ * A stateless Python function operator that exports the number of partitions (set via {@link #setNumPartitions(int)}
+ * while the StreamGraph is generated) to the Python worker as the {@link #DATA_STREAM_NUM_PARTITIONS} env variable.
+ */
+public class DataStreamPythonPartitionCustomFunctionOperator<IN, OUT> extends
+	DataStreamPythonStatelessFunctionOperator<IN, OUT> {
+
+	private static final long serialVersionUID = 1L;
+
+	/** Name of the environment variable that carries the number of partitions to the Python worker. */
+	public static final String DATA_STREAM_NUM_PARTITIONS = "DATA_STREAM_NUM_PARTITIONS";
+
+	private int numPartitions = CoreOptions.DEFAULT_PARALLELISM.defaultValue();
+
+	public DataStreamPythonPartitionCustomFunctionOperator(
+		Configuration config,
+		TypeInformation<IN> inputTypeInfo,
+		TypeInformation<OUT> outputTypeInfo,
+		DataStreamPythonFunctionInfo pythonFunctionInfo) {
+		super(config, inputTypeInfo, outputTypeInfo, pythonFunctionInfo);
+	}
+
+	@Override
+	public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
+		PythonEnvironmentManager pythonEnvironmentManager = createPythonEnvironmentManager();
+		if (pythonEnvironmentManager instanceof ProcessPythonEnvironmentManager) {
+			((ProcessPythonEnvironmentManager) pythonEnvironmentManager).setEnvironmentVariable(
+				DATA_STREAM_NUM_PARTITIONS, String.valueOf(numPartitions));
+		}
+		return new BeamDataStreamPythonStatelessFunctionRunner(
+			getRuntimeContext().getTaskName(),
+			pythonEnvironmentManager,
+			inputTypeInfo,
+			outputTypeInfo,
+			DATA_STREAM_STATELESS_PYTHON_FUNCTION_URN,
+			getUserDefinedDataStreamFunctionsProto(),
+			DATA_STREAM_MAP_FUNCTION_CODER_URN,
+			jobOptions,
+			getFlinkMetricContainer());
+	}
+
+	/** Sets the number of partitions to export; PythonConfigUtil passes the downstream node's parallelism. */
+	public void setNumPartitions(int numPartitions) {
+		this.numPartitions = numPartitions;
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/python/env/beam/ProcessPythonEnvironmentManager.java b/flink-python/src/main/java/org/apache/flink/python/env/beam/ProcessPythonEnvironmentManager.java
index be11fad..6d7ce46 100644
--- a/flink-python/src/main/java/org/apache/flink/python/env/beam/ProcessPythonEnvironmentManager.java
+++ b/flink-python/src/main/java/org/apache/flink/python/env/beam/ProcessPythonEnvironmentManager.java
@@ -240,6 +240,10 @@ public final class ProcessPythonEnvironmentManager implements PythonEnvironmentM
 		return env;
 	}
 
+	public void setEnvironmentVariable(String key, String value) {
+		this.systemEnv.put(key, value); // requires a mutable systemEnv, e.g. a copy of System.getenv()
+	}
+
 	private void constructFilesDirectory(Map<String, String> env) throws IOException {
 		// link or copy python files to filesDirectory and add them to PYTHONPATH
 		List<String> pythonFilePaths = new ArrayList<>();
diff --git a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
index 6605959..f6d13c5 100644
--- a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
+++ b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
@@ -18,6 +18,7 @@
 package org.apache.flink.python.util;
 
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.datastream.runtime.operators.python.DataStreamPythonPartitionCustomFunctionOperator;
 import org.apache.flink.datastream.runtime.operators.python.DataStreamPythonStatelessFunctionOperator;
 import org.apache.flink.python.PythonConfig;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -41,6 +42,7 @@ public class PythonConfigUtil {
 
 	public static final String KEYED_STREAM_VALUE_OPERATOR_NAME = "_keyed_stream_values_operator";
 	public static final String STREAM_KEY_BY_MAP_OPERATOR_NAME = "_stream_key_by_map_operator";
+	public static final String STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME = "_partition_custom_map_operator";
 
 	/**
 	 * A static method to get the {@link StreamExecutionEnvironment} configuration merged with python dependency
@@ -94,7 +96,8 @@ public class PythonConfigUtil {
 			downStreamEdge.setPartitioner(new ForwardPartitioner());
 		}
 
-		if (streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)) {
+		if (streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME) ||
+		streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) {
 			StreamEdge upStreamEdge = streamNode.getInEdges().get(0);
 			StreamNode upStreamNode = streamGraph.getStreamNode(upStreamEdge.getSourceId());
 			chainStreamNode(upStreamEdge, streamNode, upStreamNode);
@@ -139,9 +142,30 @@ public class PythonConfigUtil {
 				}
 			}
 		}
+
+		setStreamPartitionCustomOperatorNumPartitions(streamNodes, streamGraph);
+
 		return streamGraph;
 	}
 
+	private static void setStreamPartitionCustomOperatorNumPartitions(
+		Collection<StreamNode> streamNodes, StreamGraph streamGraph) {
+		for (StreamNode streamNode : streamNodes) {
+			StreamOperatorFactory<?> streamOperatorFactory = streamNode.getOperatorFactory();
+			if (streamOperatorFactory instanceof SimpleOperatorFactory) {
+				StreamOperator<?> streamOperator = ((SimpleOperatorFactory<?>) streamOperatorFactory).getOperator();
+				if (streamOperator instanceof DataStreamPythonPartitionCustomFunctionOperator) {
+					DataStreamPythonPartitionCustomFunctionOperator<?, ?> partitionCustomFunctionOperator =
+						(DataStreamPythonPartitionCustomFunctionOperator<?, ?>) streamOperator;
+					// The number of partitions is the parallelism of the downstream node; it is read only after
+					// all operators have been aligned so that it reflects the final parallelism.
+					partitionCustomFunctionOperator.setNumPartitions(
+						streamGraph.getStreamNode(streamNode.getOutEdges().get(0).getTargetId()).getParallelism());
+				}
+			}
+		}
+	}
+
 	/**
 	 * Generator a new {@link  PythonConfig} with the combined config which is derived from oldConfig.
 	 */
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
index 280bf0a..5a989ba 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
@@ -38,6 +38,7 @@ import org.apache.flink.table.functions.python.PythonEnv;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.concurrent.ScheduledFuture;
 
 /**
@@ -336,7 +337,7 @@ public abstract class AbstractPythonFunctionOperatorBase<OUT>
 			return new ProcessPythonEnvironmentManager(
 				dependencyInfo,
 				getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories(),
-				System.getenv());
+				new HashMap<>(System.getenv()));
 		} else {
 			throw new UnsupportedOperationException(String.format(
 				"Execution type '%s' is not supported.", pythonEnv.getExecType()));
diff --git a/flink-python/src/test/java/org/apache/flink/python/util/PartitionCustomTestMapFunction.java b/flink-python/src/test/java/org/apache/flink/python/util/PartitionCustomTestMapFunction.java
new file mode 100644
index 0000000..5484194
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/python/util/PartitionCustomTestMapFunction.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.python.util;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.types.Row;
+
+/**
+ * Test MapFunction asserting that each record's second field modulo the parallelism equals the current subtask index.
+ */
+public class PartitionCustomTestMapFunction extends RichMapFunction<Row, Row> {
+	private static final long serialVersionUID = 1L;
+	private int currentTaskIndex;
+	private int numberOfParallelSubtasks;
+
+	@Override
+	public void open(Configuration parameters) {
+		this.currentTaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+		this.numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
+	}
+
+	@Override
+	public Row map(Row value) throws Exception {
+		int expectedPartitionIndex = (Integer) value.getField(1) % numberOfParallelSubtasks;
+		if (expectedPartitionIndex != currentTaskIndex) {
+			throw new RuntimeException(String.format("the data: Row<%s> was sent to the wrong partition[%d], " +
+				"expected partition is [%d].", value.toString(), currentTaskIndex, expectedPartitionIndex));
+		}
+		return value;
+	}
+}