You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by he...@apache.org on 2020/08/18 01:33:08 UTC
[flink] branch master updated: [FLINK-18947][python] Support
partition_custom() for Python DataStream API. (#13155)
This is an automated email from the ASF dual-hosted git repository.
hequn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 917c3d2 [FLINK-18947][python] Support partition_custom() for Python DataStream API. (#13155)
917c3d2 is described below
commit 917c3d2f294f17ea8031dce2220f142c67520463
Author: Shuiqiang Chen <ac...@alibaba-inc.com>
AuthorDate: Tue Aug 18 09:31:26 2020 +0800
[FLINK-18947][python] Support partition_custom() for Python DataStream API. (#13155)
---
flink-python/pom.xml | 1 +
flink-python/pyflink/datastream/data_stream.py | 89 +++++++++++++++++++++-
flink-python/pyflink/datastream/functions.py | 44 ++++++++++-
.../pyflink/datastream/tests/test_data_stream.py | 24 ++++++
.../python/PartitionCustomKeySelector.java | 36 +++++++++
...treamPythonPartitionCustomFunctionOperator.java | 75 ++++++++++++++++++
.../env/beam/ProcessPythonEnvironmentManager.java | 4 +
.../apache/flink/python/util/PythonConfigUtil.java | 26 ++++++-
.../python/AbstractPythonFunctionOperatorBase.java | 3 +-
.../util/PartitionCustomTestMapFunction.java | 48 ++++++++++++
10 files changed, 343 insertions(+), 7 deletions(-)
diff --git a/flink-python/pom.xml b/flink-python/pom.xml
index 3246761..3b5f5aa 100644
--- a/flink-python/pom.xml
+++ b/flink-python/pom.xml
@@ -460,6 +460,7 @@ under the License.
<includes>
<include>**/DataStreamTestCollectSink.class</include>
<include>**/MyCustomSourceFunction.class</include>
+ <include>**/PartitionCustomTestMapFunction.class</include>
</includes>
<outputDirectory>
${project.build.directory}/data-stream-test
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index e227a51..2c5f526 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-
+import os
from typing import Callable, Union
from pyflink.common import typeinfo, ExecutionConfig
@@ -24,7 +24,8 @@ from pyflink.common.typeinfo import TypeInformation
from pyflink.datastream.functions import _get_python_env, FlatMapFunctionWrapper, FlatMapFunction, \
MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, FilterFunction, \
FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector, ReduceFunction, \
- ReduceFunctionWrapper, CoMapFunction, CoFlatMapFunction
+ ReduceFunctionWrapper, CoMapFunction, CoFlatMapFunction, Partitioner, \
+ PartitionerFunctionWrapper
from pyflink.java_gateway import get_gateway
@@ -445,6 +446,77 @@ class DataStream(object):
"""
return DataStream(self._j_data_stream.broadcast())
+ def partition_custom(self, partitioner: Union[Callable, Partitioner],
+ key_selector: Union[Callable, KeySelector]) -> 'DataStream':
+ """
+ Partitions a DataStream on the key returned by the selector, using a custom partitioner.
+ This method takes the key selector to get the key to partition on, and a partitioner that
+ accepts the key type.
+
+ Note that this method works only on single field keys, i.e. the selector cannot return
+ tuples of fields.
+
+ :param partitioner: The partitioner to assign partitions to keys.
+ :param key_selector: The KeySelector with which the DataStream is partitioned.
+ :return: The partitioned DataStream.
+ """
+ if callable(key_selector):
+ key_selector = KeySelectorFunctionWrapper(key_selector)
+ if not isinstance(key_selector, (KeySelector, KeySelectorFunctionWrapper)):
+ raise TypeError("Parameter key_selector should be a type of KeySelector.")
+
+ if callable(partitioner):
+ partitioner = PartitionerFunctionWrapper(partitioner)
+ if not isinstance(partitioner, (Partitioner, PartitionerFunctionWrapper)):
+ raise TypeError("Parameter partitioner should be a type of Partitioner.")
+
+ gateway = get_gateway()
+ data_stream_num_partitions_env_key = gateway.jvm\
+ .org.apache.flink.datastream.runtime.operators.python\
+ .DataStreamPythonPartitionCustomFunctionOperator.DATA_STREAM_NUM_PARTITIONS
+
+ class PartitionCustomMapFunction(MapFunction):
+ """
+ A wrapper class for the partition_custom map function. It marks the operation as a
+ partition custom operation, for which DataStreamPythonPartitionCustomFunctionOperator
+ must be applied to run the map function.
+ """
+
+ def __init__(self):
+ self.num_partitions = None
+
+ def map(self, value):
+ return self.partition_custom_map(value)
+
+ def partition_custom_map(self, value):
+ if self.num_partitions is None:
+ self.num_partitions = int(os.environ[data_stream_num_partitions_env_key])
+ partition = partitioner.partition(key_selector.get_key(value), self.num_partitions)
+ return partition, value
+
+ def __repr__(self) -> str:
+ return '_Flink_PartitionCustomMapFunction'
+
+ original_type_info = self.get_type()
+ intermediate_map_stream = self.map(PartitionCustomMapFunction(),
+ type_info=Types.ROW([Types.INT(), original_type_info]))
+ intermediate_map_stream.name(
+ gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
+ .STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)
+
+ JPartitionCustomKeySelector = gateway.jvm\
+ .org.apache.flink.datastream.runtime.functions.python.PartitionCustomKeySelector
+ JIdParitioner = gateway.jvm\
+ .org.apache.flink.api.java.functions.IdPartitioner
+ intermediate_map_stream = DataStream(intermediate_map_stream._j_data_stream
+ .partitionCustom(JIdParitioner(),
+ JPartitionCustomKeySelector()))
+
+ values_map_stream = intermediate_map_stream.map(lambda x: x[1], original_type_info)
+ values_map_stream.name(gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
+ .KEYED_STREAM_VALUE_OPERATOR_NAME)
+ return DataStream(values_map_stream._j_data_stream)
+
def _get_java_python_function_operator(self, func: Union[Function, FunctionWrapper],
type_info: TypeInformation, func_name: str,
func_type: int):
@@ -506,8 +578,13 @@ class DataStream(object):
j_python_data_stream_function_info)
return j_python_data_stream_function_operator, j_output_type_info
else:
- DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
- .operators.python.DataStreamPythonStatelessFunctionOperator
+ if str(func) == '_Flink_PartitionCustomMapFunction':
+ DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
+ .operators.python.DataStreamPythonPartitionCustomFunctionOperator
+ else:
+ DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
+ .operators.python.DataStreamPythonStatelessFunctionOperator
+
j_python_data_stream_function_operator = DataStreamPythonFunctionOperator(
j_conf,
j_input_types,
@@ -729,6 +806,10 @@ class KeyedStream(DataStream):
def broadcast(self) -> 'DataStream':
raise Exception('Cannot override partitioning for KeyedStream.')
+ def partition_custom(self, partitioner: Union[Callable, Partitioner],
+ key_selector: Union[Callable, KeySelector]) -> 'DataStream':
+ raise Exception('Cannot override partitioning for KeyedStream.')
+
def print(self, sink_identifier=None):
return self._values().print()
diff --git a/flink-python/pyflink/datastream/functions.py b/flink-python/pyflink/datastream/functions.py
index 76f5025..3f576d4 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -17,7 +17,7 @@
################################################################################
import abc
-from typing import Union
+from typing import Union, Any
from py4j.java_gateway import JavaObject
@@ -240,6 +240,23 @@ class FilterFunction(Function):
pass
+class Partitioner(Function):
+ """
+ Function to implement a custom partition assignment for keys.
+ """
+
+ @abc.abstractmethod
+ def partition(self, key: Any, num_partitions: int) -> int:
+ """
+ Computes the partition for the given key.
+
+ :param key: The key.
+ :param num_partitions: The number of partitions to partition into.
+ :return: The partition index.
+ """
+ pass
+
+
class FunctionWrapper(object):
"""
A basic wrapper class for user defined function.
@@ -360,6 +377,31 @@ class KeySelectorFunctionWrapper(FunctionWrapper):
return self._func(value)
+class PartitionerFunctionWrapper(FunctionWrapper):
+ """
+ A wrapper class for Partitioner. It wraps a user-defined function in a Partitioner when
+ the user does not implement a Partitioner but directly passes a function object or a
+ lambda function to partition_custom().
+ """
+ def __init__(self, func):
+ """
+ The constructor of PartitionerFunctionWrapper.
+
+ :param func: user defined function object.
+ """
+ super(PartitionerFunctionWrapper, self).__init__(func)
+
+ def partition(self, key: Any, num_partitions: int) -> int:
+ """
+ A delegated partition function to invoke user defined function.
+
+ :param key: The key.
+ :param num_partitions: The number of partitions to partition into.
+ :return: The partition index.
+ """
+ return self._func(key, num_partitions)
+
+
def _get_python_env():
"""
An util function to get a python user defined function execution environment.
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 2e352a7..655e847 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -19,6 +19,7 @@ import decimal
from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.datastream.data_stream import DataStream
from pyflink.datastream.functions import FilterFunction
from pyflink.datastream.functions import KeySelector
from pyflink.datastream.functions import MapFunction, FlatMapFunction
@@ -404,6 +405,29 @@ class DataStreamTests(PyFlinkTestCase):
pre_ship_strategy = shuffle_node['predecessors'][0]['ship_strategy']
self.assertEqual(pre_ship_strategy, 'SHUFFLE')
+ def test_partition_custom(self):
+ ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2),
+ ('f', 7), ('g', 7), ('h', 8), ('i', 8), ('j', 9)],
+ type_info=Types.ROW([Types.STRING(), Types.INT()]))
+
+ expected_num_partitions = 5
+
+ def my_partitioner(key, num_partitions):
+ assert expected_num_partitions == num_partitions  # compare; "assert a, b" only checks a
+ return key % num_partitions
+
+ partitioned_stream = ds.map(lambda x: x, type_info=Types.ROW([Types.STRING(),
+ Types.INT()]))\
+ .set_parallelism(4).partition_custom(my_partitioner, lambda x: x[1])
+
+ JPartitionCustomTestMapFunction = get_gateway().jvm\
+ .org.apache.flink.python.util.PartitionCustomTestMapFunction
+ test_map_stream = DataStream(partitioned_stream
+ ._j_data_stream.map(JPartitionCustomTestMapFunction()))
+ test_map_stream.set_parallelism(expected_num_partitions).add_sink(self.test_sink)
+
+ self.env.execute('test_partition_custom')
+
def test_keyed_stream_partitioning(self):
ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)])
keyed_stream = ds.key_by(lambda x: x[1])
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/functions/python/PartitionCustomKeySelector.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/functions/python/PartitionCustomKeySelector.java
new file mode 100644
index 0000000..04391b3
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/functions/python/PartitionCustomKeySelector.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.runtime.functions.python;
+
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.types.Row;
+
+/**
+ * The {@link PartitionCustomKeySelector} returns the first field of the input row value. The input value is
+ * generated by the
+ * {@link org.apache.flink.datastream.runtime.operators.python.DataStreamPythonPartitionCustomFunctionOperator} after
+ * executing the user-defined partitioner and keySelector functions. The value of the first field is the desired
+ * partition index.
+ */
+public class PartitionCustomKeySelector implements KeySelector<Row, Integer> {
+ @Override
+ public Integer getKey(Row value) throws Exception {
+ return (Integer) value.getField(0);
+ }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonPartitionCustomFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonPartitionCustomFunctionOperator.java
new file mode 100644
index 0000000..2a39935
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonPartitionCustomFunctionOperator.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.runtime.operators.python;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.CoreOptions;
+import org.apache.flink.datastream.runtime.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.datastream.runtime.runners.python.beam.BeamDataStreamPythonStatelessFunctionRunner;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.python.env.beam.ProcessPythonEnvironmentManager;
+
+/**
+ * The {@link DataStreamPythonPartitionCustomFunctionOperator} enables us to set the number of partitions for the
+ * current operator dynamically when generating the {@link org.apache.flink.streaming.api.graph.StreamGraph} before
+ * executing the job. The number of partitions is set in an environment variable for the Python worker, so that we
+ * can obtain the number of partitions when executing the user-defined partitioner function.
+ */
+public class DataStreamPythonPartitionCustomFunctionOperator<IN, OUT> extends
+ DataStreamPythonStatelessFunctionOperator<IN, OUT> {
+
+ public static final String DATA_STREAM_NUM_PARTITIONS = "DATA_STREAM_NUM_PARTITIONS";
+
+ private int numPartitions = CoreOptions.DEFAULT_PARALLELISM.defaultValue();
+
+ public DataStreamPythonPartitionCustomFunctionOperator(
+ Configuration config,
+ TypeInformation<IN> inputTypeInfo,
+ TypeInformation<OUT> outputTypeInfo,
+ DataStreamPythonFunctionInfo pythonFunctionInfo) {
+ super(config, inputTypeInfo, outputTypeInfo, pythonFunctionInfo);
+ }
+
+ @Override
+ public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
+ PythonEnvironmentManager pythonEnvironmentManager = createPythonEnvironmentManager();
+ if (pythonEnvironmentManager instanceof ProcessPythonEnvironmentManager) {
+ ProcessPythonEnvironmentManager envManager = (ProcessPythonEnvironmentManager) pythonEnvironmentManager;
+ envManager.setEnvironmentVariable(DATA_STREAM_NUM_PARTITIONS,
+ String.valueOf(this.numPartitions));
+ }
+ return new BeamDataStreamPythonStatelessFunctionRunner(
+ getRuntimeContext().getTaskName(),
+ pythonEnvironmentManager,
+ inputTypeInfo,
+ outputTypeInfo,
+ DATA_STREAM_STATELESS_PYTHON_FUNCTION_URN,
+ getUserDefinedDataStreamFunctionsProto(),
+ DATA_STREAM_MAP_FUNCTION_CODER_URN,
+ jobOptions,
+ getFlinkMetricContainer()
+ );
+ }
+
+ public void setNumPartitions(int numPartitions) {
+ this.numPartitions = numPartitions;
+ }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/python/env/beam/ProcessPythonEnvironmentManager.java b/flink-python/src/main/java/org/apache/flink/python/env/beam/ProcessPythonEnvironmentManager.java
index be11fad..6d7ce46 100644
--- a/flink-python/src/main/java/org/apache/flink/python/env/beam/ProcessPythonEnvironmentManager.java
+++ b/flink-python/src/main/java/org/apache/flink/python/env/beam/ProcessPythonEnvironmentManager.java
@@ -240,6 +240,10 @@ public final class ProcessPythonEnvironmentManager implements PythonEnvironmentM
return env;
}
+ public void setEnvironmentVariable(String key, String value) {
+ this.systemEnv.put(key, value);
+ }
+
private void constructFilesDirectory(Map<String, String> env) throws IOException {
// link or copy python files to filesDirectory and add them to PYTHONPATH
List<String> pythonFilePaths = new ArrayList<>();
diff --git a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
index 6605959..f6d13c5 100644
--- a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
+++ b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
@@ -18,6 +18,7 @@
package org.apache.flink.python.util;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.datastream.runtime.operators.python.DataStreamPythonPartitionCustomFunctionOperator;
import org.apache.flink.datastream.runtime.operators.python.DataStreamPythonStatelessFunctionOperator;
import org.apache.flink.python.PythonConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -41,6 +42,7 @@ public class PythonConfigUtil {
public static final String KEYED_STREAM_VALUE_OPERATOR_NAME = "_keyed_stream_values_operator";
public static final String STREAM_KEY_BY_MAP_OPERATOR_NAME = "_stream_key_by_map_operator";
+ public static final String STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME = "_partition_custom_map_operator";
/**
* A static method to get the {@link StreamExecutionEnvironment} configuration merged with python dependency
@@ -94,7 +96,8 @@ public class PythonConfigUtil {
downStreamEdge.setPartitioner(new ForwardPartitioner());
}
- if (streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)) {
+ if (streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME) ||
+ streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) {
StreamEdge upStreamEdge = streamNode.getInEdges().get(0);
StreamNode upStreamNode = streamGraph.getStreamNode(upStreamEdge.getSourceId());
chainStreamNode(upStreamEdge, streamNode, upStreamNode);
@@ -139,9 +142,30 @@ public class PythonConfigUtil {
}
}
}
+
+ setStreamPartitionCustomOperatorNumPartitions(streamNodes, streamGraph);
+
return streamGraph;
}
+ private static void setStreamPartitionCustomOperatorNumPartitions(
+ Collection<StreamNode> streamNodes, StreamGraph streamGraph){
+ for (StreamNode streamNode : streamNodes) {
+ StreamOperatorFactory streamOperatorFactory = streamNode.getOperatorFactory();
+ if (streamOperatorFactory instanceof SimpleOperatorFactory) {
+ StreamOperator streamOperator = ((SimpleOperatorFactory) streamOperatorFactory).getOperator();
+ if (streamOperator instanceof DataStreamPythonPartitionCustomFunctionOperator) {
+ DataStreamPythonPartitionCustomFunctionOperator paritionCustomFunctionOperator =
+ (DataStreamPythonPartitionCustomFunctionOperator) streamOperator;
+
+ // Update the numPartitions of PartitionCustomOperator after all operators have been aligned.
+ paritionCustomFunctionOperator.setNumPartitions(
+ streamGraph.getStreamNode(streamNode.getOutEdges().get(0).getTargetId()).getParallelism());
+ }
+ }
+ }
+ }
+
/**
* Generator a new {@link PythonConfig} with the combined config which is derived from oldConfig.
*/
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
index 280bf0a..5a989ba 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
@@ -38,6 +38,7 @@ import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.util.Preconditions;
import java.io.IOException;
+import java.util.HashMap;
import java.util.concurrent.ScheduledFuture;
/**
@@ -336,7 +337,7 @@ public abstract class AbstractPythonFunctionOperatorBase<OUT>
return new ProcessPythonEnvironmentManager(
dependencyInfo,
getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories(),
- System.getenv());
+ new HashMap<>(System.getenv()));
} else {
throw new UnsupportedOperationException(String.format(
"Execution type '%s' is not supported.", pythonEnv.getExecType()));
diff --git a/flink-python/src/test/java/org/apache/flink/python/util/PartitionCustomTestMapFunction.java b/flink-python/src/test/java/org/apache/flink/python/util/PartitionCustomTestMapFunction.java
new file mode 100644
index 0000000..5484194
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/python/util/PartitionCustomTestMapFunction.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.python.util;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.types.Row;
+
+/**
+ * {@link PartitionCustomTestMapFunction} is a dedicated MapFunction that verifies the second field of each row,
+ * modulo the operator parallelism, equals the current sub-task index.
+ */
+public class PartitionCustomTestMapFunction extends RichMapFunction<Row, Row> {
+
+ private int currentTaskIndex;
+
+ @Override
+ public void open(Configuration parameters) {
+ this.currentTaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+ }
+
+ @Override
+ public Row map(Row value) throws Exception {
+ int expectedPartitionIndex = (Integer) (value.getField(1)) % getRuntimeContext()
+ .getNumberOfParallelSubtasks();
+ if (expectedPartitionIndex != currentTaskIndex) {
+ throw new RuntimeException(String.format("the data: Row<%s> was sent to the wrong partition[%d], " +
+ "expected partition is [%d].", value.toString(), currentTaskIndex, expectedPartitionIndex));
+ }
+ return value;
+ }
+}