Posted to commits@flink.apache.org by he...@apache.org on 2019/08/08 08:39:22 UTC
[flink] branch release-1.9 updated: [FLINK-13594][python] Improve
the 'from_elements' method of the Flink Python API to work with the blink planner.
This is an automated email from the ASF dual-hosted git repository.
hequn pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push:
new 62d34f1 [FLINK-13594][python] Improve the 'from_elements' method of the Flink Python API to work with the blink planner.
62d34f1 is described below
commit 62d34f16d1648410db12ddbfd6fa2252d1913110
Author: Wei Zhong <we...@gmail.com>
AuthorDate: Tue Aug 6 15:27:17 2019 +0800
[FLINK-13594][python] Improve the 'from_elements' method of the Flink Python API to work with the blink planner.
This closes #9370
---
flink-python/pyflink/table/table_environment.py | 38 +++++---
flink-python/pyflink/table/tests/test_calc.py | 52 +++++++++-
flink-python/pyflink/table/tests/test_types.py | 58 ++++++++++-
flink-python/pyflink/table/types.py | 19 ++--
.../flink/api/common/python/PythonBridgeUtils.java | 108 ++++++---------------
.../flink/table/util/python/PythonTableUtils.scala | 87 ++++++++---------
.../runtime/stream/table/TableSinkITCase.scala | 8 +-
7 files changed, 214 insertions(+), 156 deletions(-)
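
Before this change, calling from_elements on a blink-planner batch environment raised NotImplementedError ("The operation 'from_elements' in batch mode is currently not supported when using blink planner."). A minimal usage sketch of what the commit enables, assuming a PyFlink 1.9 build with the blink planner on the classpath; it mirrors the new test added in test_calc.py below:

    from pyflink.table import BatchTableEnvironment, DataTypes, EnvironmentSettings

    # Blink planner in batch mode: exactly the combination that
    # from_elements used to reject with NotImplementedError.
    settings = EnvironmentSettings.new_instance() \
        .use_blink_planner().in_batch_mode().build()
    t_env = BatchTableEnvironment.create(environment_settings=settings)

    t = t_env.from_elements(
        [(1, "hi"), (2, "hello")],
        DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
                       DataTypes.FIELD("msg", DataTypes.STRING())]))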
diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py
index 85e31c1..915f9ab 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -668,12 +668,21 @@ class TableEnvironment(object):
serializer.dump_to_stream(elements, temp_file)
finally:
temp_file.close()
- return self._from_file(temp_file.name, schema)
+ row_type_info = _to_java_type(schema)
+ execution_config = self._get_execution_config(temp_file.name, schema)
+ gateway = get_gateway()
+ j_objs = gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, True)
+ j_input_format = gateway.jvm.PythonTableUtils.getInputFormat(
+ j_objs, row_type_info, execution_config)
+ j_table_source = gateway.jvm.PythonInputFormatTableSource(
+ j_input_format, row_type_info)
+
+ return Table(self._j_tenv.fromTableSource(j_table_source))
finally:
os.unlink(temp_file.name)
@abstractmethod
- def _from_file(self, filename, schema):
+ def _get_execution_config(self, filename, schema):
pass
@@ -683,12 +692,8 @@ class StreamTableEnvironment(TableEnvironment):
self._j_tenv = j_tenv
super(StreamTableEnvironment, self).__init__(j_tenv)
- def _from_file(self, filename, schema):
- gateway = get_gateway()
- jds = gateway.jvm.PythonBridgeUtils.createDataStreamFromFile(
- self._j_tenv.execEnv(), filename, True)
- return Table(gateway.jvm.PythonTableUtils.fromDataStream(
- self._j_tenv, jds, _to_java_type(schema)))
+ def _get_execution_config(self, filename, schema):
+ return self._j_tenv.execEnv().getConfig()
def get_config(self):
"""
@@ -796,18 +801,19 @@ class BatchTableEnvironment(TableEnvironment):
self._j_tenv = j_tenv
super(BatchTableEnvironment, self).__init__(j_tenv)
- def _from_file(self, filename, schema):
+ def _get_execution_config(self, filename, schema):
gateway = get_gateway()
blink_t_env_class = get_java_class(
gateway.jvm.org.apache.flink.table.api.internal.TableEnvironmentImpl)
- if blink_t_env_class == self._j_tenv.getClass():
- raise NotImplementedError("The operation 'from_elements' in batch mode is currently "
- "not supported when using blink planner.")
+ is_blink = (blink_t_env_class == self._j_tenv.getClass())
+ if is_blink:
+ # we cannot get the ExecutionConfig object from the TableEnvironmentImpl
+ # at the moment, so just create a new ExecutionConfig.
+ execution_config = gateway.jvm.org.apache.flink.api.common.ExecutionConfig()
else:
- jds = gateway.jvm.PythonBridgeUtils.createDataSetFromFile(
- self._j_tenv.execEnv(), filename, True)
- return Table(gateway.jvm.PythonTableUtils.fromDataSet(
- self._j_tenv, jds, _to_java_type(schema)))
+ execution_config = self._j_tenv.execEnv().getConfig()
+
+ return execution_config
def get_config(self):
"""
diff --git a/flink-python/pyflink/table/tests/test_calc.py b/flink-python/pyflink/table/tests/test_calc.py
index b67e083..6699efd 100644
--- a/flink-python/pyflink/table/tests/test_calc.py
+++ b/flink-python/pyflink/table/tests/test_calc.py
@@ -20,7 +20,7 @@ import array
import datetime
from decimal import Decimal
-from pyflink.table import DataTypes, Row
+from pyflink.table import DataTypes, Row, BatchTableEnvironment, EnvironmentSettings
from pyflink.table.tests.test_types import ExamplePoint, PythonOnlyPoint, ExamplePointUDT, \
PythonOnlyUDT
from pyflink.testing import source_sink_utils
@@ -97,14 +97,60 @@ class StreamTableCalcTests(PyFlinkStreamTableTestCase):
PythonOnlyPoint(3.0, 4.0))],
schema)
t.insert_into("Results")
- self.t_env.execute("test")
+ t_env.execute("test")
actual = source_sink_utils.results()
expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,'
- '1970-01-02 00:00:00.0,86400000010,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
+ '1970-01-02 00:00:00.0,86400000,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
'1,1,2.0,{key=1.0},[65, 66, 67, 68],[1.0, 2.0],[3.0, 4.0]']
self.assert_equals(actual, expected)
+ def test_blink_from_element(self):
+ t_env = BatchTableEnvironment.create(environment_settings=EnvironmentSettings
+ .new_instance().use_blink_planner()
+ .in_batch_mode().build())
+ field_names = ["a", "b", "c", "d", "e", "f", "g", "h",
+ "i", "j", "k", "l", "m", "n", "o", "p", "q", "r"]
+ field_types = [DataTypes.BIGINT(), DataTypes.DOUBLE(), DataTypes.STRING(),
+ DataTypes.STRING(), DataTypes.DATE(),
+ DataTypes.TIME(),
+ DataTypes.TIMESTAMP(),
+ DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(),
+ DataTypes.INTERVAL(DataTypes.DAY(), DataTypes.SECOND()),
+ DataTypes.ARRAY(DataTypes.DOUBLE()),
+ DataTypes.ARRAY(DataTypes.DOUBLE(False)),
+ DataTypes.ARRAY(DataTypes.STRING()),
+ DataTypes.ARRAY(DataTypes.DATE()),
+ DataTypes.DECIMAL(10, 0),
+ DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT()),
+ DataTypes.FIELD("b", DataTypes.DOUBLE())]),
+ DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()),
+ DataTypes.BYTES(),
+ PythonOnlyUDT()]
+ schema = DataTypes.ROW(
+ list(map(lambda field_name, field_type: DataTypes.FIELD(field_name, field_type),
+ field_names,
+ field_types)))
+ table_sink = source_sink_utils.TestAppendSink(field_names, field_types)
+ t_env.register_table_sink("Results", table_sink)
+ t = t_env.from_elements(
+ [(1, 1.0, "hi", "hello", datetime.date(1970, 1, 2), datetime.time(1, 0, 0),
+ datetime.datetime(1970, 1, 2, 0, 0), datetime.datetime(1970, 1, 2, 0, 0),
+ datetime.timedelta(days=1, microseconds=10),
+ [1.0, None], array.array("d", [1.0, 2.0]),
+ ["abc"], [datetime.date(1970, 1, 2)], Decimal(1), Row("a", "b")(1, 2.0),
+ {"key": 1.0}, bytearray(b'ABCD'),
+ PythonOnlyPoint(3.0, 4.0))],
+ schema)
+ t.insert_into("Results")
+ t_env.execute("test")
+ actual = source_sink_utils.results()
+
+ expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,'
+ '1970-01-02 00:00:00.0,86400000,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
+ '1.000000000000000000,1,2.0,{key=1.0},[65, 66, 67, 68],[3.0, 4.0]']
+ self.assert_equals(actual, expected)
+
if __name__ == '__main__':
import unittest
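
A note on the expectation change above (86400000010 -> 86400000): with the new conversion path the DAY TO SECOND interval is emitted with millisecond precision, so the 10 microseconds of the input timedelta are truncated. The arithmetic, in plain Python (assuming millisecond truncation, as the new expected string suggests):

    import datetime

    td = datetime.timedelta(days=1, microseconds=10)
    total_micros = (td.days * 86400 + td.seconds) * 10**6 + td.microseconds
    assert total_micros == 86400000010           # old, microsecond-based expectation
    assert total_micros // 1000 == 86400000      # millisecond precision drops the 10 us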
diff --git a/flink-python/pyflink/table/tests/test_types.py b/flink-python/pyflink/table/tests/test_types.py
index 8972fa8..9b1cbf7 100644
--- a/flink-python/pyflink/table/tests/test_types.py
+++ b/flink-python/pyflink/table/tests/test_types.py
@@ -21,8 +21,11 @@ import ctypes
import datetime
import pickle
import sys
+import tempfile
import unittest
+from pyflink.serializers import BatchedSerializer, PickleSerializer
+
from pyflink.java_gateway import get_gateway
from pyflink.table.types import (_infer_schema_from_data, _infer_type,
_array_signed_int_typecode_ctype_mappings,
@@ -825,10 +828,26 @@ class DataTypeConvertTests(unittest.TestCase):
DataTypes.DECIMAL(20, 10, False)]
self.assertEqual(converted_python_types, expected)
+ # Legacy type tests
+ Types = gateway.jvm.org.apache.flink.table.api.Types
+ BlinkBigDecimalTypeInfo = \
+ gateway.jvm.org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo
+
+ java_types = [Types.STRING(),
+ Types.DECIMAL(),
+ BlinkBigDecimalTypeInfo(12, 5)]
+
+ converted_python_types = [_from_java_type(item) for item in java_types]
+
+ expected = [DataTypes.VARCHAR(2147483647),
+ DataTypes.DECIMAL(10, 0),
+ DataTypes.DECIMAL(12, 5)]
+ self.assertEqual(converted_python_types, expected)
+
def test_array_type(self):
+ # The nullable/not_null flag is lost during the conversion.
test_types = [DataTypes.ARRAY(DataTypes.BIGINT()),
- # array type with not null basic data type means primitive array
- DataTypes.ARRAY(DataTypes.BIGINT().not_null()),
+ DataTypes.ARRAY(DataTypes.BIGINT()),
DataTypes.ARRAY(DataTypes.STRING()),
DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT())),
DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))]
@@ -879,6 +898,41 @@ class DataTypeConvertTests(unittest.TestCase):
self.assertEqual(test_types, converted_python_types)
+class DataSerializerTests(unittest.TestCase):
+
+ def test_java_pickle_deserializer(self):
+ temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
+ serializer = PickleSerializer()
+ data = [(1, 2), (3, 4), (5, 6), (7, 8)]
+
+ try:
+ serializer.dump_to_stream(data, temp_file)
+ finally:
+ temp_file.close()
+
+ gateway = get_gateway()
+ result = [tuple(int_pair) for int_pair in
+ list(gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, False))]
+
+ self.assertEqual(result, [(1, 2), (3, 4), (5, 6), (7, 8)])
+
+ def test_java_batch_deserializer(self):
+ temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
+ serializer = BatchedSerializer(PickleSerializer(), 2)
+ data = [(1, 2), (3, 4), (5, 6), (7, 8)]
+
+ try:
+ serializer.dump_to_stream(data, temp_file)
+ finally:
+ temp_file.close()
+
+ gateway = get_gateway()
+ result = [tuple(int_pair) for int_pair in
+ list(gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, True))]
+
+ self.assertEqual(result, [(1, 2), (3, 4), (5, 6), (7, 8)])
+
+
if __name__ == "__main__":
try:
import xmlrunner
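
The two new round-trip tests above pin down the file-format contract between the Python serializers and PythonBridgeUtils. For reference, a sketch of what the batched writer produces; the temp-file handling is illustrative:

    import tempfile
    from pyflink.serializers import BatchedSerializer, PickleSerializer

    # With batch size 2, the four pairs are pickled as two frames of two
    # elements each, so the Java reader must be called with batched=True
    # to flatten each frame back into individual rows.
    serializer = BatchedSerializer(PickleSerializer(), 2)
    with tempfile.NamedTemporaryFile(delete=False) as f:
        serializer.dump_to_stream([(1, 2), (3, 4), (5, 6), (7, 8)], f)
        path = f.name  # pass to PythonBridgeUtils.readPythonObjects(path, True)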
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index af48049..dbf92f8 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -1665,17 +1665,7 @@ def _to_java_type(data_type):
# ArrayType
elif isinstance(data_type, ArrayType):
- if type(data_type.element_type) in _primitive_array_element_types:
- if data_type.element_type._nullable is False:
- return Types.PRIMITIVE_ARRAY(_to_java_type(data_type.element_type))
- else:
- return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
- elif isinstance(data_type.element_type, VarCharType) or isinstance(
- data_type.element_type, CharType):
- return gateway.jvm.org.apache.flink.api.common.typeinfo.\
- BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO
- else:
- return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
+ return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
# MapType
elif isinstance(data_type, MapType):
@@ -1783,8 +1773,15 @@ def _from_java_type(j_data_type):
type_info = logical_type.getTypeInformation()
BasicArrayTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo.\
BasicArrayTypeInfo
+ BasicTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo.BasicTypeInfo
if type_info == BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO:
data_type = DataTypes.ARRAY(DataTypes.STRING())
+ elif type_info == BasicTypeInfo.BIG_DEC_TYPE_INFO:
+ data_type = DataTypes.DECIMAL(10, 0)
+ elif type_info.getClass() == \
+ get_java_class(gateway.jvm.org.apache.flink.table.runtime.typeutils
+ .BigDecimalTypeInfo):
+ data_type = DataTypes.DECIMAL(type_info.precision(), type_info.scale())
else:
raise TypeError("Unsupported type: %s, it is recognized as a legacy type."
% type_info)
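
The ArrayType simplification above means the element type's nullability no longer selects between primitive and object arrays: every array now maps to OBJECT_ARRAY, matching the updated test_array_type. A sketch of the observable effect; _to_java_type is internal API and needs an active JVM gateway (e.g. an initialized TableEnvironment):

    from pyflink.table import DataTypes
    from pyflink.table.types import _to_java_type

    # Both variants now yield ObjectArrayTypeInfo; the not_null flag on
    # the element type is dropped during the conversion.
    j_nullable = _to_java_type(DataTypes.ARRAY(DataTypes.BIGINT()))
    j_not_null = _to_java_type(DataTypes.ARRAY(DataTypes.BIGINT().not_null()))
    assert j_nullable.equals(j_not_null)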
diff --git a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
index a082a8b..44a568b 100644
--- a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
@@ -17,16 +17,8 @@
package org.apache.flink.api.common.python;
-import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.python.pickle.ArrayConstructor;
import org.apache.flink.api.common.python.pickle.ByteArrayConstructor;
-import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.api.java.DataSet;
-import org.apache.flink.api.java.ExecutionEnvironment;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.util.Collector;
import net.razorvine.pickle.Unpickler;
@@ -39,38 +31,46 @@ import java.util.LinkedList;
import java.util.List;
/**
- * Utility class that contains helper methods to create a DataStream/DataSet from
+ * Utility class that contains helper methods to create a TableSource from
* a file which contains Python objects.
*/
public final class PythonBridgeUtils {
- /**
- * Creates a DataStream from a file which contains serialized python objects.
- */
- public static DataStream<Object[]> createDataStreamFromFile(
- final StreamExecutionEnvironment streamExecutionEnvironment,
- final String fileName,
- final boolean batched) throws IOException {
- return streamExecutionEnvironment
- .fromCollection(readPythonObjects(fileName))
- .flatMap(new PythonFlatMapFunction(batched))
- .returns(Types.GENERIC(Object[].class));
+ private static Object[] getObjectArrayFromUnpickledData(Object input) {
+ if (input.getClass().isArray()) {
+ return (Object[]) input;
+ } else {
+ return ((ArrayList<Object>) input).toArray(new Object[0]);
+ }
}
- /**
- * Creates a DataSet from a file which contains serialized python objects.
- */
- public static DataSet<Object[]> createDataSetFromFile(
- final ExecutionEnvironment executionEnvironment,
- final String fileName,
- final boolean batched) throws IOException {
- return executionEnvironment
- .fromCollection(readPythonObjects(fileName))
- .flatMap(new PythonFlatMapFunction(batched))
- .returns(Types.GENERIC(Object[].class));
+ public static List<Object[]> readPythonObjects(String fileName, boolean batched)
+ throws IOException {
+ List<byte[]> data = readPickledBytes(fileName);
+ Unpickler unpickle = new Unpickler();
+ initialize();
+ List<Object[]> unpickledData = new ArrayList<>();
+ for (byte[] pickledData: data) {
+ Object obj = unpickle.loads(pickledData);
+ if (batched) {
+ if (obj instanceof Object[]) {
+ Object[] arrayObj = (Object[]) obj;
+ for (Object o : arrayObj) {
+ unpickledData.add(getObjectArrayFromUnpickledData(o));
+ }
+ } else {
+ for (Object o : (ArrayList<Object>) obj) {
+ unpickledData.add(getObjectArrayFromUnpickledData(o));
+ }
+ }
+ } else {
+ unpickledData.add(getObjectArrayFromUnpickledData(obj));
+ }
+ }
+ return unpickledData;
}
- private static List<byte[]> readPythonObjects(final String fileName) throws IOException {
+ private static List<byte[]> readPickledBytes(final String fileName) throws IOException {
List<byte[]> objs = new LinkedList<>();
try (DataInputStream din = new DataInputStream(new FileInputStream(fileName))) {
try {
@@ -87,50 +87,6 @@ public final class PythonBridgeUtils {
return objs;
}
- private static final class PythonFlatMapFunction extends RichFlatMapFunction<byte[], Object[]> {
-
- private static final long serialVersionUID = 1L;
-
- private final boolean batched;
- private transient Unpickler unpickle;
-
- PythonFlatMapFunction(boolean batched) {
- this.batched = batched;
- initialize();
- }
-
- @Override
- public void open(Configuration parameters) {
- this.unpickle = new Unpickler();
- }
-
- @Override
- public void flatMap(byte[] value, Collector<Object[]> out) throws Exception {
- Object obj = unpickle.loads(value);
- if (batched) {
- if (obj instanceof Object[]) {
- for (int i = 0; i < ((Object[]) obj).length; i++) {
- collect(out, ((Object[]) obj)[i]);
- }
- } else {
- for (Object o : (ArrayList<Object>) obj) {
- collect(out, o);
- }
- }
- } else {
- collect(out, obj);
- }
- }
-
- private void collect(Collector<Object[]> out, Object obj) {
- if (obj.getClass().isArray()) {
- out.collect((Object[]) obj);
- } else {
- out.collect(((ArrayList<Object>) obj).toArray(new Object[0]));
- }
- }
- }
-
private static boolean initialized = false;
private static void initialize() {
synchronized (PythonBridgeUtils.class) {
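
The refactoring above moves unpickling out of a runtime RichFlatMapFunction and into an eager, client-side loop; the per-object logic is unchanged. Rendered in Python for comparison (a sketch, not part of the patch):

    import pickle

    def read_python_objects_sketch(pickled_frames, batched):
        """Python rendering of the Java loop in readPythonObjects."""
        rows = []
        for frame in pickled_frames:
            obj = pickle.loads(frame)
            if batched:
                # each frame is a batch; every element becomes one row
                for element in obj:
                    rows.append(tuple(element))
            else:
                rows.append(tuple(obj))
        return rows

Eager unpickling keeps all rows on the client, which is acceptable here: from_elements is meant for small, already-materialized collections.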
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
index 1ac2423..094945c 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
@@ -24,61 +24,36 @@ import java.time.{LocalDate, LocalDateTime, LocalTime}
import java.util.TimeZone
import java.util.function.BiConsumer
-import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.common.io.InputFormat
import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, PrimitiveArrayTypeInfo, TypeInformation}
-import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.java.io.CollectionInputFormat
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
-import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.table.api.java.{BatchTableEnvironment, StreamTableEnvironment}
-import org.apache.flink.table.api.{Table, Types}
+import org.apache.flink.core.io.InputSplit
+import org.apache.flink.table.api.{TableSchema, Types}
+import org.apache.flink.table.sources.InputFormatTableSource
import org.apache.flink.types.Row
+import scala.collection.JavaConversions._
+
object PythonTableUtils {
/**
- * Converts the given [[DataStream]] into a [[Table]].
- *
- * The schema of the [[Table]] is derived from the specified schemaString.
+ * Wraps the unpickled Python data in an InputFormat. It will later be passed to
+ * a PythonInputFormatTableSource.
*
- * @param tableEnv The table environment.
- * @param dataStream The [[DataStream]] to be converted.
- * @param dataType The type information of the table.
- * @return The converted [[Table]].
+ * @param data The unpickled Python data.
+ * @param dataType The row type information of the data.
+ * @param config The execution config used to create the serializer.
+ * @return An InputFormat containing the Python data.
*/
- def fromDataStream(
- tableEnv: StreamTableEnvironment,
- dataStream: DataStream[Array[Object]],
- dataType: TypeInformation[Row]): Table = {
- val convertedDataStream = dataStream.map(
- new MapFunction[Array[Object], Row] {
- override def map(value: Array[Object]): Row =
- convertTo(dataType).apply(value).asInstanceOf[Row]
- }).returns(dataType.asInstanceOf[TypeInformation[Row]])
-
- tableEnv.fromDataStream(convertedDataStream)
- }
-
- /**
- * Converts the given [[DataSet]] into a [[Table]].
- *
- * The schema of the [[Table]] is derived from the specified schemaString.
- *
- * @param tableEnv The table environment.
- * @param dataSet The [[DataSet]] to be converted.
- * @param dataType The type information of the table.
- * @return The converted [[Table]].
- */
- def fromDataSet(
- tableEnv: BatchTableEnvironment,
- dataSet: DataSet[Array[Object]],
- dataType: TypeInformation[Row]): Table = {
- val convertedDataSet = dataSet.map(
- new MapFunction[Array[Object], Row] {
- override def map(value: Array[Object]): Row =
- convertTo(dataType).apply(value).asInstanceOf[Row]
- }).returns(dataType.asInstanceOf[TypeInformation[Row]])
-
- tableEnv.fromDataSet(convertedDataSet)
+ def getInputFormat(
+ data: java.util.List[Array[Object]],
+ dataType: TypeInformation[Row],
+ config: ExecutionConfig): InputFormat[Row, _] = {
+ val converter = convertTo(dataType)
+ new CollectionInputFormat(data.map(converter(_).asInstanceOf[Row]),
+ dataType.createSerializer(config))
}
/**
@@ -422,3 +397,23 @@ object PythonTableUtils {
result
}
}
+
+/**
+ * An InputFormatTableSource created by the Python 'from_elements' method.
+ *
+ * @param inputFormat The input format which contains the Python data collection,
+ *                    usually created by the PythonTableUtils#getInputFormat method.
+ * @param rowTypeInfo The row type info of the Python data,
+ *                    generated by the Python 'from_elements' method.
+ */
+class PythonInputFormatTableSource[Row](
+ inputFormat: InputFormat[Row, _ <: InputSplit],
+ rowTypeInfo: RowTypeInfo
+) extends InputFormatTableSource[Row] {
+
+ override def getInputFormat: InputFormat[Row, _ <: InputSplit] = inputFormat
+
+ override def getTableSchema: TableSchema = TableSchema.fromTypeInfo(rowTypeInfo)
+
+ override def getReturnType: TypeInformation[Row] = rowTypeInfo.asInstanceOf[TypeInformation[Row]]
+}
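
The PythonInputFormatTableSource added above is what lets one code path serve both planners: InputFormatTableSource is a planner-agnostic source interface, so the dedicated fromDataStream/fromDataSet conversions could be dropped. The Python side instantiates it straight through the gateway; a minimal sketch, where j_input_format and row_type_info are produced by the calls shown in the table_environment.py hunk and j_tenv is the underlying Java TableEnvironment:

    from pyflink.java_gateway import get_gateway
    from pyflink.table import Table

    # Register the source with the environment; either planner can
    # consume an InputFormatTableSource passed to fromTableSource.
    gateway = get_gateway()
    j_table_source = gateway.jvm.PythonInputFormatTableSource(
        j_input_format, row_type_info)
    table = Table(j_tenv.fromTableSource(j_table_source))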
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
index dbfdb14..cf3fcbf 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
@@ -27,7 +27,7 @@ import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.TimeCharacteristic
-import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink}
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.streaming.api.functions.sink.SinkFunction
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
@@ -621,7 +621,11 @@ private[flink] class TestAppendSink extends AppendStreamTableSink[Row] {
var fTypes: Array[TypeInformation[_]] = _
override def emitDataStream(s: DataStream[Row]): Unit = {
- s.map(
+ consumeDataStream(s)
+ }
+
+ override def consumeDataStream(dataStream: DataStream[Row]): DataStreamSink[_] = {
+ dataStream.map(
new MapFunction[Row, JTuple2[JBool, Row]] {
override def map(value: Row): JTuple2[JBool, Row] = new JTuple2(true, value)
})