You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this rendering.
Posted to commits@flink.apache.org by he...@apache.org on 2019/08/08 08:39:22 UTC

[flink] branch release-1.9 updated: [FLINK-13594][python] Improve the 'from_element' method of flink python api to apply to blink planner.

This is an automated email from the ASF dual-hosted git repository.

hequn pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.9 by this push:
     new 62d34f1  [FLINK-13594][python] Improve the 'from_element' method of flink python api to apply to blink planner.
62d34f1 is described below

commit 62d34f16d1648410db12ddbfd6fa2252d1913110
Author: Wei Zhong <we...@gmail.com>
AuthorDate: Tue Aug 6 15:27:17 2019 +0800

    [FLINK-13594][python] Improve the 'from_element' method of flink python api to apply to blink planner.
    
    This closes #9370
---
 flink-python/pyflink/table/table_environment.py    |  38 +++++---
 flink-python/pyflink/table/tests/test_calc.py      |  52 +++++++++-
 flink-python/pyflink/table/tests/test_types.py     |  58 ++++++++++-
 flink-python/pyflink/table/types.py                |  19 ++--
 .../flink/api/common/python/PythonBridgeUtils.java | 108 ++++++---------------
 .../flink/table/util/python/PythonTableUtils.scala |  87 ++++++++---------
 .../runtime/stream/table/TableSinkITCase.scala     |   8 +-
 7 files changed, 214 insertions(+), 156 deletions(-)

diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py
index 85e31c1..915f9ab 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -668,12 +668,21 @@ class TableEnvironment(object):
                 serializer.dump_to_stream(elements, temp_file)
             finally:
                 temp_file.close()
-            return self._from_file(temp_file.name, schema)
+            row_type_info = _to_java_type(schema)
+            execution_config = self._get_execution_config(temp_file.name, schema)
+            gateway = get_gateway()
+            j_objs = gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, True)
+            j_input_format = gateway.jvm.PythonTableUtils.getInputFormat(
+                j_objs, row_type_info, execution_config)
+            j_table_source = gateway.jvm.PythonInputFormatTableSource(
+                j_input_format, row_type_info)
+
+            return Table(self._j_tenv.fromTableSource(j_table_source))
         finally:
             os.unlink(temp_file.name)
 
     @abstractmethod
-    def _from_file(self, filename, schema):
+    def _get_execution_config(self, filename, schema):
         pass
 
 
@@ -683,12 +692,8 @@ class StreamTableEnvironment(TableEnvironment):
         self._j_tenv = j_tenv
         super(StreamTableEnvironment, self).__init__(j_tenv)
 
-    def _from_file(self, filename, schema):
-        gateway = get_gateway()
-        jds = gateway.jvm.PythonBridgeUtils.createDataStreamFromFile(
-            self._j_tenv.execEnv(), filename, True)
-        return Table(gateway.jvm.PythonTableUtils.fromDataStream(
-            self._j_tenv, jds, _to_java_type(schema)))
+    def _get_execution_config(self, filename, schema):
+        return self._j_tenv.execEnv().getConfig()
 
     def get_config(self):
         """
@@ -796,18 +801,19 @@ class BatchTableEnvironment(TableEnvironment):
         self._j_tenv = j_tenv
         super(BatchTableEnvironment, self).__init__(j_tenv)
 
-    def _from_file(self, filename, schema):
+    def _get_execution_config(self, filename, schema):
         gateway = get_gateway()
         blink_t_env_class = get_java_class(
             gateway.jvm.org.apache.flink.table.api.internal.TableEnvironmentImpl)
-        if blink_t_env_class == self._j_tenv.getClass():
-            raise NotImplementedError("The operation 'from_elements' in batch mode is currently "
-                                      "not supported when using blink planner.")
+        is_blink = (blink_t_env_class == self._j_tenv.getClass())
+        if is_blink:
+            # we can not get ExecutionConfig object from the TableEnvironmentImpl
+            # for the moment, just create a new ExecutionConfig.
+            execution_config = gateway.jvm.org.apache.flink.api.common.ExecutionConfig()
         else:
-            jds = gateway.jvm.PythonBridgeUtils.createDataSetFromFile(
-                self._j_tenv.execEnv(), filename, True)
-            return Table(gateway.jvm.PythonTableUtils.fromDataSet(
-                self._j_tenv, jds, _to_java_type(schema)))
+            execution_config = self._j_tenv.execEnv().getConfig()
+
+        return execution_config
 
     def get_config(self):
         """
diff --git a/flink-python/pyflink/table/tests/test_calc.py b/flink-python/pyflink/table/tests/test_calc.py
index b67e083..6699efd 100644
--- a/flink-python/pyflink/table/tests/test_calc.py
+++ b/flink-python/pyflink/table/tests/test_calc.py
@@ -20,7 +20,7 @@ import array
 import datetime
 from decimal import Decimal
 
-from pyflink.table import DataTypes, Row
+from pyflink.table import DataTypes, Row, BatchTableEnvironment, EnvironmentSettings
 from pyflink.table.tests.test_types import ExamplePoint, PythonOnlyPoint, ExamplePointUDT, \
     PythonOnlyUDT
 from pyflink.testing import source_sink_utils
@@ -97,14 +97,60 @@ class StreamTableCalcTests(PyFlinkStreamTableTestCase):
               PythonOnlyPoint(3.0, 4.0))],
             schema)
         t.insert_into("Results")
-        self.t_env.execute("test")
+        t_env.execute("test")
         actual = source_sink_utils.results()
 
         expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,'
-                    '1970-01-02 00:00:00.0,86400000010,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
+                    '1970-01-02 00:00:00.0,86400000,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
                     '1,1,2.0,{key=1.0},[65, 66, 67, 68],[1.0, 2.0],[3.0, 4.0]']
         self.assert_equals(actual, expected)
 
+    def test_blink_from_element(self):
+        t_env = BatchTableEnvironment.create(environment_settings=EnvironmentSettings
+                                             .new_instance().use_blink_planner()
+                                             .in_batch_mode().build())
+        field_names = ["a", "b", "c", "d", "e", "f", "g", "h",
+                       "i", "j", "k", "l", "m", "n", "o", "p", "q", "r"]
+        field_types = [DataTypes.BIGINT(), DataTypes.DOUBLE(), DataTypes.STRING(),
+                       DataTypes.STRING(), DataTypes.DATE(),
+                       DataTypes.TIME(),
+                       DataTypes.TIMESTAMP(),
+                       DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(),
+                       DataTypes.INTERVAL(DataTypes.DAY(), DataTypes.SECOND()),
+                       DataTypes.ARRAY(DataTypes.DOUBLE()),
+                       DataTypes.ARRAY(DataTypes.DOUBLE(False)),
+                       DataTypes.ARRAY(DataTypes.STRING()),
+                       DataTypes.ARRAY(DataTypes.DATE()),
+                       DataTypes.DECIMAL(10, 0),
+                       DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT()),
+                                      DataTypes.FIELD("b", DataTypes.DOUBLE())]),
+                       DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()),
+                       DataTypes.BYTES(),
+                       PythonOnlyUDT()]
+        schema = DataTypes.ROW(
+            list(map(lambda field_name, field_type: DataTypes.FIELD(field_name, field_type),
+                 field_names,
+                 field_types)))
+        table_sink = source_sink_utils.TestAppendSink(field_names, field_types)
+        t_env.register_table_sink("Results", table_sink)
+        t = t_env.from_elements(
+            [(1, 1.0, "hi", "hello", datetime.date(1970, 1, 2), datetime.time(1, 0, 0),
+              datetime.datetime(1970, 1, 2, 0, 0), datetime.datetime(1970, 1, 2, 0, 0),
+              datetime.timedelta(days=1, microseconds=10),
+              [1.0, None], array.array("d", [1.0, 2.0]),
+              ["abc"], [datetime.date(1970, 1, 2)], Decimal(1), Row("a", "b")(1, 2.0),
+              {"key": 1.0}, bytearray(b'ABCD'),
+              PythonOnlyPoint(3.0, 4.0))],
+            schema)
+        t.insert_into("Results")
+        t_env.execute("test")
+        actual = source_sink_utils.results()
+
+        expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,'
+                    '1970-01-02 00:00:00.0,86400000,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
+                    '1.000000000000000000,1,2.0,{key=1.0},[65, 66, 67, 68],[3.0, 4.0]']
+        self.assert_equals(actual, expected)
+
 
 if __name__ == '__main__':
     import unittest
diff --git a/flink-python/pyflink/table/tests/test_types.py b/flink-python/pyflink/table/tests/test_types.py
index 8972fa8..9b1cbf7 100644
--- a/flink-python/pyflink/table/tests/test_types.py
+++ b/flink-python/pyflink/table/tests/test_types.py
@@ -21,8 +21,11 @@ import ctypes
 import datetime
 import pickle
 import sys
+import tempfile
 import unittest
 
+from pyflink.serializers import BatchedSerializer, PickleSerializer
+
 from pyflink.java_gateway import get_gateway
 from pyflink.table.types import (_infer_schema_from_data, _infer_type,
                                  _array_signed_int_typecode_ctype_mappings,
@@ -825,10 +828,26 @@ class DataTypeConvertTests(unittest.TestCase):
                     DataTypes.DECIMAL(20, 10, False)]
         self.assertEqual(converted_python_types, expected)
 
+        # Legacy type tests
+        Types = gateway.jvm.org.apache.flink.table.api.Types
+        BlinkBigDecimalTypeInfo = \
+            gateway.jvm.org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo
+
+        java_types = [Types.STRING(),
+                      Types.DECIMAL(),
+                      BlinkBigDecimalTypeInfo(12, 5)]
+
+        converted_python_types = [_from_java_type(item) for item in java_types]
+
+        expected = [DataTypes.VARCHAR(2147483647),
+                    DataTypes.DECIMAL(10, 0),
+                    DataTypes.DECIMAL(12, 5)]
+        self.assertEqual(converted_python_types, expected)
+
     def test_array_type(self):
+        # nullable/not_null flag will be lost during the conversion.
         test_types = [DataTypes.ARRAY(DataTypes.BIGINT()),
-                      # array type with not null basic data type means primitive array
-                      DataTypes.ARRAY(DataTypes.BIGINT().not_null()),
+                      DataTypes.ARRAY(DataTypes.BIGINT()),
                       DataTypes.ARRAY(DataTypes.STRING()),
                       DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT())),
                       DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))]
@@ -879,6 +898,41 @@ class DataTypeConvertTests(unittest.TestCase):
         self.assertEqual(test_types, converted_python_types)
 
 
+class DataSerializerTests(unittest.TestCase):
+
+    def test_java_pickle_deserializer(self):
+        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
+        serializer = PickleSerializer()
+        data = [(1, 2), (3, 4), (5, 6), (7, 8)]
+
+        try:
+            serializer.dump_to_stream(data, temp_file)
+        finally:
+            temp_file.close()
+
+        gateway = get_gateway()
+        result = [tuple(int_pair) for int_pair in
+                  list(gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, False))]
+
+        self.assertEqual(result, [(1, 2), (3, 4), (5, 6), (7, 8)])
+
+    def test_java_batch_deserializer(self):
+        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
+        serializer = BatchedSerializer(PickleSerializer(), 2)
+        data = [(1, 2), (3, 4), (5, 6), (7, 8)]
+
+        try:
+            serializer.dump_to_stream(data, temp_file)
+        finally:
+            temp_file.close()
+
+        gateway = get_gateway()
+        result = [tuple(int_pair) for int_pair in
+                  list(gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, True))]
+
+        self.assertEqual(result, [(1, 2), (3, 4), (5, 6), (7, 8)])
+
+
 if __name__ == "__main__":
     try:
         import xmlrunner
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index af48049..dbf92f8 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -1665,17 +1665,7 @@ def _to_java_type(data_type):
 
     # ArrayType
     elif isinstance(data_type, ArrayType):
-        if type(data_type.element_type) in _primitive_array_element_types:
-            if data_type.element_type._nullable is False:
-                return Types.PRIMITIVE_ARRAY(_to_java_type(data_type.element_type))
-            else:
-                return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
-        elif isinstance(data_type.element_type, VarCharType) or isinstance(
-                data_type.element_type, CharType):
-            return gateway.jvm.org.apache.flink.api.common.typeinfo.\
-                BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO
-        else:
-            return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
+        return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
 
     # MapType
     elif isinstance(data_type, MapType):
@@ -1783,8 +1773,15 @@ def _from_java_type(j_data_type):
             type_info = logical_type.getTypeInformation()
             BasicArrayTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo.\
                 BasicArrayTypeInfo
+            BasicTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo.BasicTypeInfo
             if type_info == BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO:
                 data_type = DataTypes.ARRAY(DataTypes.STRING())
+            elif type_info == BasicTypeInfo.BIG_DEC_TYPE_INFO:
+                data_type = DataTypes.DECIMAL(10, 0)
+            elif type_info.getClass() == \
+                get_java_class(gateway.jvm.org.apache.flink.table.runtime.typeutils
+                               .BigDecimalTypeInfo):
+                data_type = DataTypes.DECIMAL(type_info.precision(), type_info.scale())
             else:
                 raise TypeError("Unsupported type: %s, it is recognized as a legacy type."
                                 % type_info)
diff --git a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
index a082a8b..44a568b 100644
--- a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
@@ -17,16 +17,8 @@
 
 package org.apache.flink.api.common.python;
 
-import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.python.pickle.ArrayConstructor;
 import org.apache.flink.api.common.python.pickle.ByteArrayConstructor;
-import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.api.java.DataSet;
-import org.apache.flink.api.java.ExecutionEnvironment;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.util.Collector;
 
 import net.razorvine.pickle.Unpickler;
 
@@ -39,38 +31,46 @@ import java.util.LinkedList;
 import java.util.List;
 
 /**
- * Utility class that contains helper methods to create a DataStream/DataSet from
+ * Utility class that contains helper methods to create a TableSource from
  * a file which contains Python objects.
  */
 public final class PythonBridgeUtils {
 
-	/**
-	 * Creates a DataStream from a file which contains serialized python objects.
-	 */
-	public static DataStream<Object[]> createDataStreamFromFile(
-		final StreamExecutionEnvironment streamExecutionEnvironment,
-		final String fileName,
-		final boolean batched) throws IOException {
-		return streamExecutionEnvironment
-			.fromCollection(readPythonObjects(fileName))
-			.flatMap(new PythonFlatMapFunction(batched))
-			.returns(Types.GENERIC(Object[].class));
+	private static Object[] getObjectArrayFromUnpickledData(Object input) {
+		if (input.getClass().isArray()) {
+			return (Object[]) input;
+		} else {
+			return ((ArrayList<Object>) input).toArray(new Object[0]);
+		}
 	}
 
-	/**
-	 * Creates a DataSet from a file which contains serialized python objects.
-	 */
-	public static DataSet<Object[]> createDataSetFromFile(
-		final ExecutionEnvironment executionEnvironment,
-		final String fileName,
-		final boolean batched) throws IOException {
-		return executionEnvironment
-			.fromCollection(readPythonObjects(fileName))
-			.flatMap(new PythonFlatMapFunction(batched))
-			.returns(Types.GENERIC(Object[].class));
+	public static List<Object[]> readPythonObjects(String fileName, boolean batched)
+		throws IOException {
+		List<byte[]> data = readPickledBytes(fileName);
+		Unpickler unpickle = new Unpickler();
+		initialize();
+		List<Object[]> unpickledData = new ArrayList<>();
+		for (byte[] pickledData: data) {
+			Object obj = unpickle.loads(pickledData);
+			if (batched) {
+				if (obj instanceof Object[]) {
+					Object[] arrayObj = (Object[]) obj;
+					for (Object o : arrayObj) {
+						unpickledData.add(getObjectArrayFromUnpickledData(o));
+					}
+				} else {
+					for (Object o : (ArrayList<Object>) obj) {
+						unpickledData.add(getObjectArrayFromUnpickledData(o));
+					}
+				}
+			} else {
+				unpickledData.add(getObjectArrayFromUnpickledData(obj));
+			}
+		}
+		return unpickledData;
 	}
 
-	private static List<byte[]> readPythonObjects(final String fileName) throws IOException {
+	private static List<byte[]> readPickledBytes(final String fileName) throws IOException {
 		List<byte[]> objs = new LinkedList<>();
 		try (DataInputStream din = new DataInputStream(new FileInputStream(fileName))) {
 			try {
@@ -87,50 +87,6 @@ public final class PythonBridgeUtils {
 		return objs;
 	}
 
-	private static final class PythonFlatMapFunction extends RichFlatMapFunction<byte[], Object[]> {
-
-		private static final long serialVersionUID = 1L;
-
-		private final boolean batched;
-		private transient Unpickler unpickle;
-
-		PythonFlatMapFunction(boolean batched) {
-			this.batched = batched;
-			initialize();
-		}
-
-		@Override
-		public void open(Configuration parameters) {
-			this.unpickle = new Unpickler();
-		}
-
-		@Override
-		public void flatMap(byte[] value, Collector<Object[]> out) throws Exception {
-			Object obj = unpickle.loads(value);
-			if (batched) {
-				if (obj instanceof Object[]) {
-					for (int i = 0; i < ((Object[]) obj).length; i++) {
-						collect(out, ((Object[]) obj)[i]);
-					}
-				} else {
-					for (Object o : (ArrayList<Object>) obj) {
-						collect(out, o);
-					}
-				}
-			} else {
-				collect(out, obj);
-			}
-		}
-
-		private void collect(Collector<Object[]> out, Object obj) {
-			if (obj.getClass().isArray()) {
-				out.collect((Object[]) obj);
-			} else {
-				out.collect(((ArrayList<Object>) obj).toArray(new Object[0]));
-			}
-		}
-	}
-
 	private static boolean initialized = false;
 	private static void initialize() {
 		synchronized (PythonBridgeUtils.class) {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
index 1ac2423..094945c 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
@@ -24,61 +24,36 @@ import java.time.{LocalDate, LocalDateTime, LocalTime}
 import java.util.TimeZone
 import java.util.function.BiConsumer
 
-import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.common.io.InputFormat
 import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, PrimitiveArrayTypeInfo, TypeInformation}
-import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.java.io.CollectionInputFormat
 import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
-import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.table.api.java.{BatchTableEnvironment, StreamTableEnvironment}
-import org.apache.flink.table.api.{Table, Types}
+import org.apache.flink.core.io.InputSplit
+import org.apache.flink.table.api.{TableSchema, Types}
+import org.apache.flink.table.sources.InputFormatTableSource
 import org.apache.flink.types.Row
 
+import scala.collection.JavaConversions._
+
 object PythonTableUtils {
 
   /**
-    * Converts the given [[DataStream]] into a [[Table]].
-    *
-    * The schema of the [[Table]] is derived from the specified schemaString.
+    * Wrap the unpickled python data with an InputFormat. It will be passed to
+    * PythonInputFormatTableSource later.
     *
-    * @param tableEnv The table environment.
-    * @param dataStream The [[DataStream]] to be converted.
-    * @param dataType The type information of the table.
-    * @return The converted [[Table]].
+    * @param data The unpickled python data.
+    * @param dataType The python data type.
+    * @param config The execution config used to create serializer.
+    * @return An InputFormat containing the python data.
     */
-  def fromDataStream(
-      tableEnv: StreamTableEnvironment,
-      dataStream: DataStream[Array[Object]],
-      dataType: TypeInformation[Row]): Table = {
-    val convertedDataStream = dataStream.map(
-      new MapFunction[Array[Object], Row] {
-        override def map(value: Array[Object]): Row =
-          convertTo(dataType).apply(value).asInstanceOf[Row]
-      }).returns(dataType.asInstanceOf[TypeInformation[Row]])
-
-    tableEnv.fromDataStream(convertedDataStream)
-  }
-
-  /**
-    * Converts the given [[DataSet]] into a [[Table]].
-    *
-    * The schema of the [[Table]] is derived from the specified schemaString.
-    *
-    * @param tableEnv The table environment.
-    * @param dataSet The [[DataSet]] to be converted.
-    * @param dataType The type information of the table.
-    * @return The converted [[Table]].
-    */
-  def fromDataSet(
-      tableEnv: BatchTableEnvironment,
-      dataSet: DataSet[Array[Object]],
-      dataType: TypeInformation[Row]): Table = {
-    val convertedDataSet = dataSet.map(
-      new MapFunction[Array[Object], Row] {
-        override def map(value: Array[Object]): Row =
-          convertTo(dataType).apply(value).asInstanceOf[Row]
-      }).returns(dataType.asInstanceOf[TypeInformation[Row]])
-
-    tableEnv.fromDataSet(convertedDataSet)
+  def getInputFormat(
+      data: java.util.List[Array[Object]],
+      dataType: TypeInformation[Row],
+      config: ExecutionConfig): InputFormat[Row, _] = {
+    val converter = convertTo(dataType)
+    new CollectionInputFormat(data.map(converter(_).asInstanceOf[Row]),
+      dataType.createSerializer(config))
   }
 
   /**
@@ -422,3 +397,23 @@ object PythonTableUtils {
     result
   }
 }
+
+/**
+  * An InputFormatTableSource created by python 'from_element' method.
+  *
+  * @param inputFormat The input format which contains the python data collection,
+  *                    usually created by PythonTableUtils#getInputFormat method
+  * @param rowTypeInfo The row type info of the python data.
+  *                    It is generated by the python 'from_element' method.
+  */
+class PythonInputFormatTableSource[Row](
+    inputFormat: InputFormat[Row, _ <: InputSplit],
+    rowTypeInfo: RowTypeInfo
+) extends InputFormatTableSource[Row] {
+
+  override def getInputFormat: InputFormat[Row, _ <: InputSplit] = inputFormat
+
+  override def getTableSchema: TableSchema = TableSchema.fromTypeInfo(rowTypeInfo)
+
+  override def getReturnType: TypeInformation[Row] = rowTypeInfo.asInstanceOf[TypeInformation[Row]]
+}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
index dbfdb14..cf3fcbf 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
@@ -27,7 +27,7 @@ import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.api.scala._
 import org.apache.flink.streaming.api.TimeCharacteristic
-import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink}
 import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.streaming.api.functions.sink.SinkFunction
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
@@ -621,7 +621,11 @@ private[flink] class TestAppendSink extends AppendStreamTableSink[Row] {
   var fTypes: Array[TypeInformation[_]] = _
 
   override def emitDataStream(s: DataStream[Row]): Unit = {
-    s.map(
+    consumeDataStream(s)
+  }
+
+  override def consumeDataStream(dataStream: DataStream[Row]): DataStreamSink[_] = {
+    dataStream.map(
       new MapFunction[Row, JTuple2[JBool, Row]] {
         override def map(value: Row): JTuple2[JBool, Row] = new JTuple2(true, value)
       })