You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2020/04/27 07:28:15 UTC

[GitHub] [flink] hequn8128 commented on a change in pull request #11832: [FLINK-17148][python] Support converting pandas DataFrame to Flink Table

hequn8128 commented on a change in pull request #11832:
URL: https://github.com/apache/flink/pull/11832#discussion_r415254771



##########
File path: flink-python/pyflink/table/tests/test_pandas_conversion.py
##########
@@ -0,0 +1,147 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import datetime
+import decimal
+
+from pyflink.table.types import DataTypes, Row
+from pyflink.testing import source_sink_utils
+from pyflink.testing.test_case_utils import PyFlinkBlinkBatchTableTestCase, \
+    PyFlinkBlinkStreamTableTestCase, PyFlinkStreamTableTestCase
+
+
+class PandasConversionTestBase(object):
+
+    @classmethod
+    def setUpClass(cls):
+        super(PandasConversionTestBase, cls).setUpClass()
+        cls.data = [(1, 1, 1, 1, True, 1.1, 1.2, 'hello', bytearray(b"aaa"),
+                     decimal.Decimal('1000000000000000000.01'), datetime.date(2014, 9, 13),
+                     datetime.time(hour=1, minute=0, second=1),
+                     datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
+                     Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
+                         d=[1, 2])),
+                    (2, 2, 2, 2, False, 2.1, 2.2, 'world', bytearray(b"bbb"),
+                     decimal.Decimal('1000000000000000000.02'), datetime.date(2014, 9, 13),
+                     datetime.time(hour=1, minute=0, second=1),
+                     datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
+                     Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
+                         d=[1, 2]))]
+        cls.data_type = DataTypes.ROW(
+            [DataTypes.FIELD("f1", DataTypes.TINYINT()),
+             DataTypes.FIELD("f2", DataTypes.SMALLINT()),
+             DataTypes.FIELD("f3", DataTypes.INT()),
+             DataTypes.FIELD("f4", DataTypes.BIGINT()),
+             DataTypes.FIELD("f5", DataTypes.BOOLEAN()),
+             DataTypes.FIELD("f6", DataTypes.FLOAT()),
+             DataTypes.FIELD("f7", DataTypes.DOUBLE()),
+             DataTypes.FIELD("f8", DataTypes.STRING()),
+             DataTypes.FIELD("f9", DataTypes.BYTES()),
+             DataTypes.FIELD("f10", DataTypes.DECIMAL(38, 18)),
+             DataTypes.FIELD("f11", DataTypes.DATE()),
+             DataTypes.FIELD("f12", DataTypes.TIME()),
+             DataTypes.FIELD("f13", DataTypes.TIMESTAMP(3)),
+             DataTypes.FIELD("f14", DataTypes.ARRAY(DataTypes.STRING())),
+             DataTypes.FIELD("f15", DataTypes.ROW(
+                 [DataTypes.FIELD("a", DataTypes.INT()),
+                  DataTypes.FIELD("b", DataTypes.STRING()),
+                  DataTypes.FIELD("c", DataTypes.TIMESTAMP(3)),
+                  DataTypes.FIELD("d", DataTypes.ARRAY(DataTypes.INT()))]))])
+        cls.pdf = cls.create_pandas_data_frame()
+
+    @classmethod
+    def create_pandas_data_frame(cls):
+        data_dict = {}
+        for j, name in enumerate(cls.data_type.names):
+            data_dict[name] = [cls.data[i][j] for i in range(len(cls.data))]
+        # need convert to numpy types
+        import numpy as np
+        data_dict["f1"] = np.int8(data_dict["f1"])
+        data_dict["f2"] = np.int16(data_dict["f2"])
+        data_dict["f3"] = np.int32(data_dict["f3"])
+        data_dict["f4"] = np.int64(data_dict["f4"])
+        data_dict["f6"] = np.float32(data_dict["f6"])
+        data_dict["f7"] = np.float64(data_dict["f7"])
+        data_dict["f15"] = [row.as_dict() for row in data_dict["f15"]]
+        import pandas as pd
+        return pd.DataFrame(data=data_dict,
+                            columns=['f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
+                                     'f10', 'f11', 'f12', 'f13', 'f14', 'f15'])
+
+
+class PandasConversionTests(PandasConversionTestBase):
+
+    def test_from_pandas_with_incorrect_schema(self):
+        fields = self.data_type.fields.copy()
+        fields[0], fields[7] = fields[7], fields[0]  # swap str with tinyint
+        wrong_schema = DataTypes.ROW(fields)  # should be DataTypes.STRING()
+        with self.assertRaisesRegex(Exception, "Expected a string.*got int8"):
+            self.t_env.from_pandas(self.pdf, schema=wrong_schema)
+
+    def test_from_pandas_with_names(self):
+        # skip decimal as currently only decimal(38, 18) is supported
+        pdf = self.pdf.drop(['f10', 'f11', 'f12', 'f13', 'f14', 'f15'], axis=1)
+        new_names = list(map(str, range(len(pdf.columns))))
+        table = self.t_env.from_pandas(pdf, schema=new_names)
+        self.assertEqual(new_names, table.get_schema().get_field_names())
+        table = self.t_env.from_pandas(pdf, schema=tuple(new_names))
+        self.assertEqual(new_names, table.get_schema().get_field_names())
+
+    def test_from_pandas_with_types(self):
+        new_types = self.data_type.field_types()
+        new_types[0] = DataTypes.BIGINT()
+        table = self.t_env.from_pandas(self.pdf, schema=new_types)
+        self.assertEqual(new_types, table.get_schema().get_field_data_types())
+        table = self.t_env.from_pandas(self.pdf, schema=tuple(new_types))
+        self.assertEqual(new_types, table.get_schema().get_field_data_types())
+
+
+class PandasConversionITTests(PandasConversionTestBase):
+
+    def test_from_pandas(self):
+        table = self.t_env.from_pandas(self.pdf, self.data_type, 5)
+        self.assertEqual(self.data_type, table.get_schema().to_row_data_type())
+
+        table = table.filter("f1 < 2")
+        table_sink = source_sink_utils.TestAppendSink(
+            self.data_type.field_names(),
+            self.data_type.field_types())
+        self.t_env.register_table_sink("Results", table_sink)
+        table.insert_into("Results")
+        self.t_env.execute("test")
+        actual = source_sink_utils.results()
+        self.assert_equals(actual,
+                           ["1,1,1,1,true,1.1,1.2,hello,[97, 97, 97],"
+                            "1000000000000000000.010000000000000000,2014-09-13,01:00:01,"
+                            "1970-01-01 00:00:00.123,[hello, 中文],1,hello,"
+                            "1970-01-01 00:00:00.123,[1, 2]"])
+
+
+class StreamPandasConversionTests(PandasConversionITTests,

Review comment:
       Can we also cover the batch mode for the old planner? 

##########
File path: flink-python/pyflink/table/tests/test_pandas_conversion.py
##########
@@ -0,0 +1,147 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import datetime
+import decimal
+
+from pyflink.table.types import DataTypes, Row
+from pyflink.testing import source_sink_utils
+from pyflink.testing.test_case_utils import PyFlinkBlinkBatchTableTestCase, \
+    PyFlinkBlinkStreamTableTestCase, PyFlinkStreamTableTestCase
+
+
+class PandasConversionTestBase(object):
+
+    @classmethod
+    def setUpClass(cls):

Review comment:
       I noticed that we should use lowercase names for these test methods. However, that is not related to this PR; maybe we can create another JIRA issue to address it. 

##########
File path: flink-python/pyflink/table/tests/test_pandas_conversion.py
##########
@@ -0,0 +1,147 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import datetime
+import decimal
+
+from pyflink.table.types import DataTypes, Row
+from pyflink.testing import source_sink_utils
+from pyflink.testing.test_case_utils import PyFlinkBlinkBatchTableTestCase, \
+    PyFlinkBlinkStreamTableTestCase, PyFlinkStreamTableTestCase
+
+
+class PandasConversionTestBase(object):
+
+    @classmethod
+    def setUpClass(cls):
+        super(PandasConversionTestBase, cls).setUpClass()
+        cls.data = [(1, 1, 1, 1, True, 1.1, 1.2, 'hello', bytearray(b"aaa"),
+                     decimal.Decimal('1000000000000000000.01'), datetime.date(2014, 9, 13),
+                     datetime.time(hour=1, minute=0, second=1),
+                     datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
+                     Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
+                         d=[1, 2])),
+                    (2, 2, 2, 2, False, 2.1, 2.2, 'world', bytearray(b"bbb"),
+                     decimal.Decimal('1000000000000000000.02'), datetime.date(2014, 9, 13),
+                     datetime.time(hour=1, minute=0, second=1),
+                     datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
+                     Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
+                         d=[1, 2]))]
+        cls.data_type = DataTypes.ROW(
+            [DataTypes.FIELD("f1", DataTypes.TINYINT()),
+             DataTypes.FIELD("f2", DataTypes.SMALLINT()),
+             DataTypes.FIELD("f3", DataTypes.INT()),
+             DataTypes.FIELD("f4", DataTypes.BIGINT()),
+             DataTypes.FIELD("f5", DataTypes.BOOLEAN()),
+             DataTypes.FIELD("f6", DataTypes.FLOAT()),
+             DataTypes.FIELD("f7", DataTypes.DOUBLE()),
+             DataTypes.FIELD("f8", DataTypes.STRING()),
+             DataTypes.FIELD("f9", DataTypes.BYTES()),
+             DataTypes.FIELD("f10", DataTypes.DECIMAL(38, 18)),
+             DataTypes.FIELD("f11", DataTypes.DATE()),
+             DataTypes.FIELD("f12", DataTypes.TIME()),
+             DataTypes.FIELD("f13", DataTypes.TIMESTAMP(3)),
+             DataTypes.FIELD("f14", DataTypes.ARRAY(DataTypes.STRING())),
+             DataTypes.FIELD("f15", DataTypes.ROW(
+                 [DataTypes.FIELD("a", DataTypes.INT()),
+                  DataTypes.FIELD("b", DataTypes.STRING()),
+                  DataTypes.FIELD("c", DataTypes.TIMESTAMP(3)),
+                  DataTypes.FIELD("d", DataTypes.ARRAY(DataTypes.INT()))]))])
+        cls.pdf = cls.create_pandas_data_frame()
+
+    @classmethod
+    def create_pandas_data_frame(cls):
+        data_dict = {}
+        for j, name in enumerate(cls.data_type.names):
+            data_dict[name] = [cls.data[i][j] for i in range(len(cls.data))]
+        # need convert to numpy types

Review comment:
       Why do we need to convert to NumPy types?

##########
File path: flink-python/src/main/java/org/apache/flink/table/runtime/arrow/sources/AbstractArrowSourceFunction.java
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.sources;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.runtime.arrow.ArrowReader;
+import org.apache.flink.table.runtime.arrow.ArrowUtils;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.util.ArrayDeque;
+import java.util.Deque;
+
+/**
+ * An Arrow {@link SourceFunction} which takes the serialized arrow record batch data as input.
+ *
+ * @param <OUT> The type of the records produced by this source.
+ */
+@Internal
+public abstract class AbstractArrowSourceFunction<OUT>
+		extends RichParallelSourceFunction<OUT>
+		implements ResultTypeQueryable<OUT>, CheckpointedFunction {
+
+	private static final long serialVersionUID = 1L;
+
+	static {
+		ArrowUtils.checkArrowUsable();
+	}
+
+	/**
+	 * The type of the records produced by this source.
+	 */
+	final DataType dataType;
+
+	/**
+	 * The array of byte array of the source data. Each element is an array
+	 * representing an arrow batch.
+	 */
+	private final byte[][] arrowData;
+
+	/**
+	 * Allocator which is used for byte buffer allocation.
+	 */
+	private transient BufferAllocator allocator;
+
+	/**
+	 * Container that holds a set of vectors for the source data to emit.
+	 */
+	private transient VectorSchemaRoot root;
+
+	private transient volatile boolean running;
+
+	/**
+	 * The indexes of the collection of source data to emit. Each element is a tuple of
+	 * the index of the arrow batch and the staring index inside the arrow batch.
+	 */
+	private transient Deque<Tuple2<Integer, Integer>> indexesToEmit;
+
+	/**
+	 * The indexes of the source data which have not been emitted.
+	 */
+	private transient ListState<Tuple2<Integer, Integer>> checkpointedState;
+
+	AbstractArrowSourceFunction(DataType dataType, byte[][] arrowData) {
+		this.dataType = Preconditions.checkNotNull(dataType);
+		this.arrowData = Preconditions.checkNotNull(arrowData);
+	}
+
+	@Override
+	public void open(Configuration parameters) throws Exception {
+		allocator = ArrowUtils.getRootAllocator().newChildAllocator("ArrowSourceFunction", 0, Long.MAX_VALUE);
+		root = VectorSchemaRoot.create(ArrowUtils.toArrowSchema((RowType) dataType.getLogicalType()), allocator);
+		running = true;
+	}
+
+	@Override
+	public void close() throws Exception {
+		try {
+			super.close();
+		} finally {
+			if (root != null) {
+				root.close();
+				root = null;
+			}
+			if (allocator != null) {
+				allocator.close();
+				allocator = null;
+			}
+		}
+	}
+
+	@Override
+	public void initializeState(FunctionInitializationContext context) throws Exception {
+		Preconditions.checkState(this.checkpointedState == null,
+			"The " + getClass().getSimpleName() + " has already been initialized.");
+
+		this.checkpointedState = context.getOperatorStateStore().getListState(
+			new ListStateDescriptor<>(
+				"arrow-source-state",
+				new TupleSerializer<>(
+					(Class<Tuple2<Integer, Integer>>) (Class<?>) Tuple2.class,
+					new TypeSerializer[]{IntSerializer.INSTANCE, IntSerializer.INSTANCE})
+			)
+		);
+
+		this.indexesToEmit = new ArrayDeque<>();
+		if (context.isRestored()) {
+			// upon restoring
+			for (Tuple2<Integer, Integer> v : this.checkpointedState.get()) {
+				this.indexesToEmit.add(v);
+			}
+		} else {
+			// the first time the job is executed
+			final int stepSize = getRuntimeContext().getNumberOfParallelSubtasks();
+			final int taskIdx = getRuntimeContext().getIndexOfThisSubtask();
+
+			for (int i = taskIdx; i < arrowData.length; i += stepSize) {
+				this.indexesToEmit.add(Tuple2.of(i, 0));
+			}
+		}
+	}
+
+	@Override
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
+		Preconditions.checkState(this.checkpointedState != null,
+			"The " + getClass().getSimpleName() + " state has not been properly initialized.");
+
+		this.checkpointedState.clear();
+		for (Tuple2<Integer, Integer> v : indexesToEmit) {
+			this.checkpointedState.add(v);
+		}
+	}
+
+	@Override
+	public void run(SourceContext<OUT> ctx) throws Exception {
+		VectorLoader vectorLoader = new VectorLoader(root);
+		while (running && !indexesToEmit.isEmpty()) {
+			Tuple2<Integer, Integer> indexToEmit = indexesToEmit.peek();
+			ArrowRecordBatch arrowRecordBatch = loadBatch(indexToEmit.f0);
+			vectorLoader.load(arrowRecordBatch);
+			arrowRecordBatch.close();
+
+			ArrowReader<OUT> arrowReader = createArrowReader(root);
+			int rowCount = root.getRowCount();
+			int nextRowId = indexToEmit.f1;
+			while (nextRowId < rowCount) {
+				OUT element = arrowReader.read(nextRowId);
+				synchronized (ctx.getCheckpointLock()) {
+					ctx.collect(element);
+					indexToEmit.setField(++nextRowId, 1);
+				}
+			}
+
+			synchronized (ctx.getCheckpointLock()) {
+				indexesToEmit.pop();
+			}
+		}
+	}
+
+	@Override
+	public void cancel() {
+		running = false;
+	}
+
+	public abstract ArrowReader<OUT> createArrowReader(VectorSchemaRoot root);

Review comment:
       Make this method `protected` instead of `public`.

##########
File path: flink-python/pyflink/table/table_environment.py
##########
@@ -1107,6 +1107,63 @@ def _from_elements(self, elements, schema):
         finally:
             os.unlink(temp_file.name)
 
+    def from_pandas(self, pdf,

Review comment:
       Please add detailed Python docs for this API.
   BTW, do we plan to add Flink documentation for this API in another PR? If so, we can first create a JIRA issue under FLINK-17146 to track it.

##########
File path: flink-python/src/test/java/org/apache/flink/table/runtime/arrow/sources/ArrowSourceFunctionTestBase.java
##########
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.sources;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.testutils.MultiShotLatch;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.runtime.arrow.ArrowUtils;
+import org.apache.flink.table.runtime.arrow.ArrowWriter;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Abstract test base for the Arrow source function processing.
+ */
+public abstract class ArrowSourceFunctionTestBase<T> {
+
+	static DataType dataType;
+	private static BufferAllocator allocator;
+
+	@BeforeClass
+	public static void init() {
+		dataType = DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.STRING()));
+		allocator = ArrowUtils.getRootAllocator().newChildAllocator("stdout", 0, Long.MAX_VALUE);
+	}
+
+	@Test
+	public void testRestore() throws Exception {
+		Tuple2<List<T>, Integer> testData = getTestData();
+		final AbstractArrowSourceFunction<T> arrowSourceFunction =
+			createTestArrowSourceFunction(testData.f0, testData.f1);
+
+		final AbstractStreamOperatorTestHarness<T> testHarness =
+			new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction), 1, 1, 0);
+		testHarness.open();
+
+		final Throwable[] error = new Throwable[1];
+		final MultiShotLatch latch = new MultiShotLatch();
+		final AtomicInteger numOfEmittedElements = new AtomicInteger(0);
+
+		final DummySourceContext<T> sourceContext = new DummySourceContext<T>() {
+			@Override
+			public void collect(T element) {
+				if (numOfEmittedElements.get() == 2) {
+					latch.trigger();
+					// fail the source function at the the second element
+					throw new RuntimeException("Fail the arrow source");
+				}
+				numOfEmittedElements.incrementAndGet();
+			}
+		};
+
+		// run the source asynchronously
+		Thread runner = new Thread(() -> {
+			try {
+				arrowSourceFunction.run(sourceContext);
+			} catch (Throwable t) {
+				if (!t.getMessage().equals("Fail the arrow source")) {
+					error[0] = t;
+				}
+			}
+		});
+		runner.start();
+
+		if (!latch.isTriggered()) {
+			latch.await();
+		}
+
+		OperatorSubtaskState snapshot;
+		synchronized (sourceContext.getCheckpointLock()) {
+			snapshot = testHarness.snapshot(0, 0);
+		}
+
+		runner.join();
+		testHarness.close();
+
+		final AbstractArrowSourceFunction<T> arrowSourceFunction2 =
+			createTestArrowSourceFunction(testData.f0, testData.f1);
+		AbstractStreamOperatorTestHarness<T> testHarnessCopy =
+			new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction2), 1, 1, 0);
+		testHarnessCopy.initializeState(snapshot);
+		testHarnessCopy.open();
+
+		// run the source asynchronously
+		Thread runner2 = new Thread(() -> {
+			try {
+				arrowSourceFunction2.run(new DummySourceContext<T>() {
+					@Override
+					public void collect(T element) {
+						if (numOfEmittedElements.incrementAndGet() == testData.f0.size()) {
+							latch.trigger();
+						}
+					}
+				});
+			} catch (Throwable t) {
+				error[0] = t;
+			}
+		});
+		runner2.start();
+
+		if (!latch.isTriggered()) {
+			latch.await();
+		}
+		runner2.join();
+
+		Assert.assertNull(error[0]);
+		Assert.assertEquals(testData.f0.size(), numOfEmittedElements.get());

Review comment:
       Could we also verify the content of the emitted data?

##########
File path: flink-python/src/test/java/org/apache/flink/table/runtime/arrow/sources/ArrowSourceFunctionTestBase.java
##########
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.sources;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.testutils.MultiShotLatch;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.runtime.arrow.ArrowUtils;
+import org.apache.flink.table.runtime.arrow.ArrowWriter;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Abstract test base for the Arrow source function processing.
+ */
+public abstract class ArrowSourceFunctionTestBase<T> {
+
+	static DataType dataType;
+	private static BufferAllocator allocator;
+
+	@BeforeClass
+	public static void init() {
+		dataType = DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.STRING()));
+		allocator = ArrowUtils.getRootAllocator().newChildAllocator("stdout", 0, Long.MAX_VALUE);
+	}
+
+	@Test
+	public void testRestore() throws Exception {
+		Tuple2<List<T>, Integer> testData = getTestData();
+		final AbstractArrowSourceFunction<T> arrowSourceFunction =
+			createTestArrowSourceFunction(testData.f0, testData.f1);
+
+		final AbstractStreamOperatorTestHarness<T> testHarness =
+			new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction), 1, 1, 0);
+		testHarness.open();
+
+		final Throwable[] error = new Throwable[1];
+		final MultiShotLatch latch = new MultiShotLatch();
+		final AtomicInteger numOfEmittedElements = new AtomicInteger(0);
+
+		final DummySourceContext<T> sourceContext = new DummySourceContext<T>() {
+			@Override
+			public void collect(T element) {
+				if (numOfEmittedElements.get() == 2) {
+					latch.trigger();
+					// fail the source function at the the second element
+					throw new RuntimeException("Fail the arrow source");
+				}
+				numOfEmittedElements.incrementAndGet();
+			}
+		};
+
+		// run the source asynchronously
+		Thread runner = new Thread(() -> {
+			try {
+				arrowSourceFunction.run(sourceContext);
+			} catch (Throwable t) {
+				if (!t.getMessage().equals("Fail the arrow source")) {
+					error[0] = t;

Review comment:
       Could we add a corresponding assertion to verify that error[0] is not null?

##########
File path: flink-python/src/main/java/org/apache/flink/table/runtime/arrow/sources/AbstractArrowSourceFunction.java
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.sources;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.runtime.arrow.ArrowReader;
+import org.apache.flink.table.runtime.arrow.ArrowUtils;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.util.ArrayDeque;
+import java.util.Deque;
+
+/**
+ * An Arrow {@link SourceFunction} which takes the serialized arrow record batch data as input.
+ *
+ * @param <OUT> The type of the records produced by this source.
+ */
+@Internal
+public abstract class AbstractArrowSourceFunction<OUT>
+		extends RichParallelSourceFunction<OUT>
+		implements ResultTypeQueryable<OUT>, CheckpointedFunction {
+
+	private static final long serialVersionUID = 1L;
+
+	static {
+		ArrowUtils.checkArrowUsable();
+	}
+
+	/**
+	 * The type of the records produced by this source.
+	 */
+	final DataType dataType;
+
+	/**
+	 * The array of byte array of the source data. Each element is an array
+	 * representing an arrow batch.
+	 */
+	private final byte[][] arrowData;
+
+	/**
+	 * Allocator which is used for byte buffer allocation.
+	 */
+	private transient BufferAllocator allocator;
+
+	/**
+	 * Container that holds a set of vectors for the source data to emit.
+	 */
+	private transient VectorSchemaRoot root;
+
+	private transient volatile boolean running;
+
+	/**
+	 * The indexes of the collection of source data to emit. Each element is a tuple of
+	 * the index of the arrow batch and the staring index inside the arrow batch.
+	 */
+	private transient Deque<Tuple2<Integer, Integer>> indexesToEmit;
+
+	/**
+	 * The indexes of the source data which have not been emitted.
+	 */
+	private transient ListState<Tuple2<Integer, Integer>> checkpointedState;
+
+	AbstractArrowSourceFunction(DataType dataType, byte[][] arrowData) {
+		this.dataType = Preconditions.checkNotNull(dataType);
+		this.arrowData = Preconditions.checkNotNull(arrowData);
+	}
+
+	@Override
+	public void open(Configuration parameters) throws Exception {
+		allocator = ArrowUtils.getRootAllocator().newChildAllocator("ArrowSourceFunction", 0, Long.MAX_VALUE);
+		root = VectorSchemaRoot.create(ArrowUtils.toArrowSchema((RowType) dataType.getLogicalType()), allocator);
+		running = true;
+	}
+
+	@Override
+	public void close() throws Exception {
+		try {
+			super.close();
+		} finally {
+			if (root != null) {
+				root.close();
+				root = null;
+			}
+			if (allocator != null) {
+				allocator.close();
+				allocator = null;
+			}
+		}
+	}
+
+	@Override
+	public void initializeState(FunctionInitializationContext context) throws Exception {

Review comment:
       Maybe add some logging in this method? For example, log the restored state with LOG.info.

##########
File path: flink-python/src/test/java/org/apache/flink/table/runtime/arrow/sources/ArrowSourceFunctionTestBase.java
##########
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.sources;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.testutils.MultiShotLatch;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.runtime.arrow.ArrowUtils;
+import org.apache.flink.table.runtime.arrow.ArrowWriter;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Abstract test base for the Arrow source function processing.
+ */
+public abstract class ArrowSourceFunctionTestBase<T> {
+
+	static DataType dataType;
+	private static BufferAllocator allocator;
+
+	@BeforeClass
+	public static void init() {
+		dataType = DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.STRING()));
+		allocator = ArrowUtils.getRootAllocator().newChildAllocator("stdout", 0, Long.MAX_VALUE);
+	}
+
+	@Test
+	public void testRestore() throws Exception {
+		Tuple2<List<T>, Integer> testData = getTestData();
+		final AbstractArrowSourceFunction<T> arrowSourceFunction =
+			createTestArrowSourceFunction(testData.f0, testData.f1);
+
+		final AbstractStreamOperatorTestHarness<T> testHarness =
+			new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction), 1, 1, 0);
+		testHarness.open();
+
+		final Throwable[] error = new Throwable[1];
+		final MultiShotLatch latch = new MultiShotLatch();
+		final AtomicInteger numOfEmittedElements = new AtomicInteger(0);
+
+		final DummySourceContext<T> sourceContext = new DummySourceContext<T>() {
+			@Override
+			public void collect(T element) {
+				if (numOfEmittedElements.get() == 2) {
+					latch.trigger();
+					// fail the source function at the the second element
+					throw new RuntimeException("Fail the arrow source");
+				}
+				numOfEmittedElements.incrementAndGet();
+			}
+		};
+
+		// run the source asynchronously
+		Thread runner = new Thread(() -> {
+			try {
+				arrowSourceFunction.run(sourceContext);
+			} catch (Throwable t) {
+				if (!t.getMessage().equals("Fail the arrow source")) {
+					error[0] = t;
+				}
+			}
+		});
+		runner.start();
+
+		if (!latch.isTriggered()) {
+			latch.await();
+		}
+
+		OperatorSubtaskState snapshot;
+		synchronized (sourceContext.getCheckpointLock()) {
+			snapshot = testHarness.snapshot(0, 0);
+		}
+
+		runner.join();
+		testHarness.close();
+
+		final AbstractArrowSourceFunction<T> arrowSourceFunction2 =
+			createTestArrowSourceFunction(testData.f0, testData.f1);
+		AbstractStreamOperatorTestHarness<T> testHarnessCopy =
+			new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction2), 1, 1, 0);
+		testHarnessCopy.initializeState(snapshot);
+		testHarnessCopy.open();
+
+		// run the source asynchronously
+		Thread runner2 = new Thread(() -> {
+			try {
+				arrowSourceFunction2.run(new DummySourceContext<T>() {
+					@Override
+					public void collect(T element) {
+						if (numOfEmittedElements.incrementAndGet() == testData.f0.size()) {
+							latch.trigger();
+						}
+					}
+				});
+			} catch (Throwable t) {
+				error[0] = t;
+			}
+		});
+		runner2.start();
+
+		if (!latch.isTriggered()) {
+			latch.await();
+		}
+		runner2.join();
+
+		Assert.assertNull(error[0]);
+		Assert.assertEquals(testData.f0.size(), numOfEmittedElements.get());
+	}
+
+	@Test
+	public void testParallelProcessing() throws Exception {
+		Tuple2<List<T>, Integer> testData = getTestData();
+		final AbstractArrowSourceFunction<T> arrowSourceFunction =
+			createTestArrowSourceFunction(testData.f0, testData.f1);
+
+		final AbstractStreamOperatorTestHarness<T> testHarness =
+			new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction), 2, 2, 0);
+		testHarness.open();
+
+		final Throwable[] error = new Throwable[2];
+		final OneShotLatch latch = new OneShotLatch();
+		final AtomicInteger numOfEmittedElements = new AtomicInteger(0);
+
+		// run the source asynchronously
+		Thread runner = new Thread(() -> {
+			try {
+				arrowSourceFunction.run(new DummySourceContext<T>() {
+					@Override
+					public void collect(T element) {
+						if (numOfEmittedElements.incrementAndGet() == testData.f0.size()) {
+							latch.trigger();
+						}
+					}
+				});
+			} catch (Throwable t) {
+				error[0] = t;
+			}
+		});
+		runner.start();
+
+		final AbstractArrowSourceFunction<T> arrowSourceFunction2 =
+			createTestArrowSourceFunction(testData.f0, testData.f1);
+		final AbstractStreamOperatorTestHarness<T> testHarness2 =
+			new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction2), 2, 2, 1);
+		testHarness2.open();
+
+		// run the source asynchronously
+		Thread runner2 = new Thread(() -> {
+			try {
+				arrowSourceFunction2.run(new DummySourceContext<T>() {
+					@Override
+					public void collect(T element) {
+						if (numOfEmittedElements.incrementAndGet() == testData.f0.size()) {
+							latch.trigger();
+						}
+					}
+				});
+			} catch (Throwable t) {
+				error[1] = t;
+			}
+		});
+		runner2.start();
+
+		if (!latch.isTriggered()) {
+			latch.await();
+		}
+
+		runner.join();
+		runner2.join();
+		testHarness.close();
+		testHarness2.close();
+
+		Assert.assertNull(error[0]);
+		Assert.assertNull(error[1]);
+		Assert.assertEquals(testData.f0.size(), numOfEmittedElements.get());

Review comment:
       Also verify the content of the emitted data, not just the number of emitted elements?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org