Posted to commits@flink.apache.org by di...@apache.org on 2023/01/11 02:46:23 UTC

[flink] branch master updated (cfe794b792b -> afc8fb08ab0)

This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


    from cfe794b792b [FLINK-30485][Connector/RabbitMQ] Remove RabbitMQ connector. This closes #21550
     new b781a13dd61 [FLINK-30607][python] Support MapType for Pandas UDF
     new 72c9cfdb6e8 [FLINK-30607][python] Support BinaryType for Pandas UDF
     new afc8fb08ab0 [FLINK-30607][python] Support NoneType for Pandas UDF

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 flink-python/pyflink/fn_execution/coders.py        |  11 +-
 .../pyflink/fn_execution/flink_fn_execution_pb2.py | 122 +++++++++----------
 .../pyflink/proto/flink-fn-execution.proto         |   1 +
 .../pyflink/table/tests/test_pandas_udf.py         |  59 ++++++++--
 flink-python/pyflink/table/types.py                |  12 ++
 .../flink/table/runtime/arrow/ArrowUtils.java      |  84 +++++++++++++
 ...umnVector.java => ArrowBinaryColumnVector.java} |  16 +--
 .../arrow/vectors/ArrowMapColumnVector.java        |  60 ++++++++++
 .../arrow/vectors/ArrowNullColumnVector.java       |  20 ++--
 .../{VarBinaryWriter.java => BinaryWriter.java}    |  38 +++---
 .../table/runtime/arrow/writers/MapWriter.java     | 130 +++++++++++++++++++++
 .../table/runtime/arrow/writers/NullWriter.java    |  16 ++-
 12 files changed, 454 insertions(+), 115 deletions(-)
 copy flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/{ArrowVarBinaryColumnVector.java => ArrowBinaryColumnVector.java} (72%)
 create mode 100644 flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java
 copy flink-runtime/src/main/java/org/apache/flink/runtime/util/NonClosingInputStreamDecorator.java => flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowNullColumnVector.java (67%)
 copy flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/{VarBinaryWriter.java => BinaryWriter.java} (60%)
 create mode 100644 flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java
 copy flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/BeginStatementSetOperation.java => flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/NullWriter.java (73%)
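
Taken together, the three patches let MAP, BINARY and NULL columns cross the Arrow
boundary of a Pandas UDF. As a usage-level illustration (not part of the commits), here
is a minimal sketch of the new type mappings, assuming the to_arrow_type changes shown
in the types.py hunks below and a standard PyFlink/pyarrow installation:

    import pyarrow as pa
    from pyflink.table import DataTypes
    from pyflink.table.types import NullType, to_arrow_type

    # Mappings introduced by this change set (see the types.py hunks below):
    assert to_arrow_type(DataTypes.BINARY(5)) == pa.binary(5)   # fixed-size binary
    assert to_arrow_type(NullType()) == pa.null()               # null type
    assert to_arrow_type(
        DataTypes.MAP(DataTypes.STRING(False), DataTypes.STRING())
    ) == pa.map_(pa.utf8(), pa.utf8())                          # map type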


[flink] 02/03: [FLINK-30607][python] Support BinaryType for Pandas UDF

Posted by di...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 72c9cfdb6e8324d9982bf7d05306c9dac98a0bfd
Author: Dian Fu <di...@apache.org>
AuthorDate: Tue Jan 10 16:00:15 2023 +0800

    [FLINK-30607][python] Support BinaryType for Pandas UDF
    
    This closes #21639.
---
 flink-python/pyflink/fn_execution/coders.py        |  5 +-
 .../pyflink/table/tests/test_pandas_udf.py         | 21 +++--
 flink-python/pyflink/table/types.py                |  2 +
 .../flink/table/runtime/arrow/ArrowUtils.java      | 15 ++++
 .../arrow/vectors/ArrowBinaryColumnVector.java     | 48 +++++++++++
 .../table/runtime/arrow/writers/BinaryWriter.java  | 95 ++++++++++++++++++++++
 6 files changed, 180 insertions(+), 6 deletions(-)
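
The test hunk in this commit exercises the new type end-to-end. As a standalone
illustration (table wiring omitted; udf and DataTypes come from pyflink.table), a
minimal sketch of a Pandas UDF over a fixed-length BINARY column, mirroring
binary_func in the test below:

    import pandas as pd
    from pyflink.table import DataTypes
    from pyflink.table.udf import udf

    @udf(result_type=DataTypes.BINARY(5), func_type="pandas")
    def binary_func(binary_param):
        # Each element of the Series is a bytes value of fixed length 5.
        assert isinstance(binary_param, pd.Series)
        return binary_param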

diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py
index 704d278666b..190efb5dea6 100644
--- a/flink-python/pyflink/fn_execution/coders.py
+++ b/flink-python/pyflink/fn_execution/coders.py
@@ -29,7 +29,8 @@ from pyflink.common.typeinfo import TypeInformation, BasicTypeInfo, BasicType, D
     ExternalTypeInfo
 from pyflink.table.types import TinyIntType, SmallIntType, IntType, BigIntType, BooleanType, \
     FloatType, DoubleType, VarCharType, VarBinaryType, DecimalType, DateType, TimeType, \
-    LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType, MapType
+    LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType, MapType, \
+    BinaryType
 
 try:
     from pyflink.fn_execution import coder_impl_fast as coder_impl
@@ -125,6 +126,8 @@ class LengthPrefixBaseCoder(ABC):
             return DoubleType(field_type.nullable)
         elif field_type.type_name == flink_fn_execution_pb2.Schema.VARCHAR:
             return VarCharType(0x7fffffff, field_type.nullable)
+        elif field_type.type_name == flink_fn_execution_pb2.Schema.BINARY:
+            return BinaryType(field_type.binary_info.length, field_type.nullable)
         elif field_type.type_name == flink_fn_execution_pb2.Schema.VARBINARY:
             return VarBinaryType(0x7fffffff, field_type.nullable)
         elif field_type.type_name == flink_fn_execution_pb2.Schema.DECIMAL:
diff --git a/flink-python/pyflink/table/tests/test_pandas_udf.py b/flink-python/pyflink/table/tests/test_pandas_udf.py
index 46d9560b351..0e4d9d9a8ab 100644
--- a/flink-python/pyflink/table/tests/test_pandas_udf.py
+++ b/flink-python/pyflink/table/tests/test_pandas_udf.py
@@ -220,6 +220,14 @@ class PandasUDFITTests(object):
             assert isinstance(map_param, pd.Series)
             return map_param
 
+        @udf(result_type=DataTypes.BINARY(5), func_type="pandas")
+        def binary_func(binary_param):
+            assert isinstance(binary_param, pd.Series)
+            assert isinstance(binary_param[0], bytes), \
+                'binary_param of wrong type %s !' % type(binary_param[0])
+            assert len(binary_param[0]) == 5
+            return binary_param
+
         sink_table_ddl = """
             CREATE TABLE Results_test_all_data_types(
                 a TINYINT,
@@ -243,7 +251,8 @@ class PandasUDFITTests(object):
                 s ARRAY<INT>,
                 t ARRAY<STRING>,
                 u ROW<f1 INT, f2 STRING, f3 TIMESTAMP(3), f4 ARRAY<INT>>,
-                v MAP<STRING, STRING>
+                v MAP<STRING, STRING>,
+                w BINARY(5)
             ) WITH ('connector'='test-sink')
         """
         self.t_env.execute_sql(sink_table_ddl)
@@ -255,7 +264,7 @@ class PandasUDFITTests(object):
               datetime.date(2014, 9, 13), datetime.time(hour=1, minute=0, second=1),
               timestamp_value, ['hello', '中文', None], [timestamp_value], [1, 2],
               [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]),
-              {"1": "hello", "2": "world"})],
+              {"1": "hello", "2": "world"}, bytearray(b'flink'))],
             DataTypes.ROW(
                 [DataTypes.FIELD("a", DataTypes.TINYINT()),
                  DataTypes.FIELD("b", DataTypes.SMALLINT()),
@@ -278,7 +287,8 @@ class PandasUDFITTests(object):
                  DataTypes.FIELD("s", DataTypes.ARRAY(DataTypes.INT())),
                  DataTypes.FIELD("t", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))),
                  DataTypes.FIELD("u", row_type),
-                 DataTypes.FIELD("v", map_type)]))
+                 DataTypes.FIELD("v", map_type),
+                 DataTypes.FIELD("w", DataTypes.BINARY(5))]))
 
         t.select(
             tinyint_func(t.a),
@@ -302,7 +312,8 @@ class PandasUDFITTests(object):
             array_int_func(t.s),
             nested_array_func(t.t),
             row_func(t.u),
-            map_func(t.v)) \
+            map_func(t.v),
+            binary_func(t.k)) \
             .execute_insert("Results_test_all_data_types").wait()
         actual = source_sink_utils.results()
         self.assert_equals(
@@ -312,7 +323,7 @@ class PandasUDFITTests(object):
              "1000000000000000000.059999999999999999, 2014-09-13, 01:00:01, "
              "1970-01-02T00:00:00.123, [hello, 中文, null], [1970-01-02T00:00:00.123], "
              "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02T00:00:00.123, [1, 2]], "
-             "{1=hello, 2=world}]"])
+             "{1=hello, 2=world}, [102, 108, 105, 110, 107]]"])
 
     def test_invalid_pandas_udf(self):
 
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index 83d54629dad..c8227f66042 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -2281,6 +2281,8 @@ def to_arrow_type(data_type: DataType):
         return pa.float64()
     elif isinstance(data_type, (CharType, VarCharType)):
         return pa.utf8()
+    elif isinstance(data_type, BinaryType):
+        return pa.binary(data_type.length)
     elif isinstance(data_type, VarBinaryType):
         return pa.binary()
     elif isinstance(data_type, DecimalType):
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
index 2d4e09e4553..bdd691b5c9d 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
@@ -35,6 +35,7 @@ import org.apache.flink.table.operations.OutputConversionModifyOperation;
 import org.apache.flink.table.runtime.arrow.sources.ArrowTableSource;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowArrayColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowBigIntColumnVector;
+import org.apache.flink.table.runtime.arrow.vectors.ArrowBinaryColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowBooleanColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowDateColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowDecimalColumnVector;
@@ -52,6 +53,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowVarCharColumnVector;
 import org.apache.flink.table.runtime.arrow.writers.ArrayWriter;
 import org.apache.flink.table.runtime.arrow.writers.ArrowFieldWriter;
 import org.apache.flink.table.runtime.arrow.writers.BigIntWriter;
+import org.apache.flink.table.runtime.arrow.writers.BinaryWriter;
 import org.apache.flink.table.runtime.arrow.writers.BooleanWriter;
 import org.apache.flink.table.runtime.arrow.writers.DateWriter;
 import org.apache.flink.table.runtime.arrow.writers.DecimalWriter;
@@ -69,6 +71,7 @@ import org.apache.flink.table.runtime.arrow.writers.VarCharWriter;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.logical.ArrayType;
 import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.BinaryType;
 import org.apache.flink.table.types.logical.BooleanType;
 import org.apache.flink.table.types.logical.CharType;
 import org.apache.flink.table.types.logical.DateType;
@@ -103,6 +106,7 @@ import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.DateDayVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
 import org.apache.arrow.vector.Float4Vector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
@@ -255,6 +259,8 @@ public final class ArrowUtils {
             return DoubleWriter.forRow((Float8Vector) vector);
         } else if (vector instanceof VarCharVector) {
             return VarCharWriter.forRow((VarCharVector) vector);
+        } else if (vector instanceof FixedSizeBinaryVector) {
+            return BinaryWriter.forRow((FixedSizeBinaryVector) vector);
         } else if (vector instanceof VarBinaryVector) {
             return VarBinaryWriter.forRow((VarBinaryVector) vector);
         } else if (vector instanceof DecimalVector) {
@@ -328,6 +334,8 @@ public final class ArrowUtils {
             return DoubleWriter.forArray((Float8Vector) vector);
         } else if (vector instanceof VarCharVector) {
             return VarCharWriter.forArray((VarCharVector) vector);
+        } else if (vector instanceof FixedSizeBinaryVector) {
+            return BinaryWriter.forArray((FixedSizeBinaryVector) vector);
         } else if (vector instanceof VarBinaryVector) {
             return VarBinaryWriter.forArray((VarBinaryVector) vector);
         } else if (vector instanceof DecimalVector) {
@@ -411,6 +419,8 @@ public final class ArrowUtils {
             return new ArrowDoubleColumnVector((Float8Vector) vector);
         } else if (vector instanceof VarCharVector) {
             return new ArrowVarCharColumnVector((VarCharVector) vector);
+        } else if (vector instanceof FixedSizeBinaryVector) {
+            return new ArrowBinaryColumnVector((FixedSizeBinaryVector) vector);
         } else if (vector instanceof VarBinaryVector) {
             return new ArrowVarBinaryColumnVector((VarBinaryVector) vector);
         } else if (vector instanceof DecimalVector) {
@@ -715,6 +725,11 @@ public final class ArrowUtils {
             return ArrowType.Utf8.INSTANCE;
         }
 
+        @Override
+        public ArrowType visit(BinaryType binaryType) {
+            return new ArrowType.FixedSizeBinary(binaryType.getLength());
+        }
+
         @Override
         public ArrowType visit(VarBinaryType varCharType) {
             return ArrowType.Binary.INSTANCE;
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowBinaryColumnVector.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowBinaryColumnVector.java
new file mode 100644
index 00000000000..6a22011739a
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowBinaryColumnVector.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.vectors;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.columnar.vector.BytesColumnVector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.vector.FixedSizeBinaryVector;
+
+/** Arrow column vector for Binary. */
+@Internal
+public final class ArrowBinaryColumnVector implements BytesColumnVector {
+
+    /** Container used to store the sequence of fixed-size binary values of a column to read. */
+    private final FixedSizeBinaryVector fixedSizeBinaryVector;
+
+    public ArrowBinaryColumnVector(FixedSizeBinaryVector fixedSizeBinaryVector) {
+        this.fixedSizeBinaryVector = Preconditions.checkNotNull(fixedSizeBinaryVector);
+    }
+
+    @Override
+    public Bytes getBytes(int i) {
+        byte[] bytes = fixedSizeBinaryVector.get(i);
+        return new Bytes(bytes, 0, bytes.length);
+    }
+
+    @Override
+    public boolean isNullAt(int i) {
+        return fixedSizeBinaryVector.isNull(i);
+    }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/BinaryWriter.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/BinaryWriter.java
new file mode 100644
index 00000000000..af9f4724a18
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/BinaryWriter.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.writers;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.RowData;
+
+import org.apache.arrow.vector.FixedSizeBinaryVector;
+
+/** {@link ArrowFieldWriter} for Binary. */
+@Internal
+public abstract class BinaryWriter<T> extends ArrowFieldWriter<T> {
+
+    public static BinaryWriter<RowData> forRow(FixedSizeBinaryVector fixedSizeBinaryVector) {
+        return new BinaryWriterForRow(fixedSizeBinaryVector);
+    }
+
+    public static BinaryWriter<ArrayData> forArray(FixedSizeBinaryVector fixedSizeBinaryVector) {
+        return new BinaryWriterForArray(fixedSizeBinaryVector);
+    }
+
+    // ------------------------------------------------------------------------------------------
+
+    private BinaryWriter(FixedSizeBinaryVector fixedSizeBinaryVector) {
+        super(fixedSizeBinaryVector);
+    }
+
+    abstract boolean isNullAt(T in, int ordinal);
+
+    abstract byte[] readBinary(T in, int ordinal);
+
+    @Override
+    public void doWrite(T in, int ordinal) {
+        if (isNullAt(in, ordinal)) {
+            ((FixedSizeBinaryVector) getValueVector()).setNull(getCount());
+        } else {
+            ((FixedSizeBinaryVector) getValueVector()).setSafe(getCount(), readBinary(in, ordinal));
+        }
+    }
+
+    // ------------------------------------------------------------------------------------------
+
+    /** {@link BinaryWriter} for {@link RowData} input. */
+    public static final class BinaryWriterForRow extends BinaryWriter<RowData> {
+
+        private BinaryWriterForRow(FixedSizeBinaryVector fixedSizeBinaryVector) {
+            super(fixedSizeBinaryVector);
+        }
+
+        @Override
+        boolean isNullAt(RowData in, int ordinal) {
+            return in.isNullAt(ordinal);
+        }
+
+        @Override
+        byte[] readBinary(RowData in, int ordinal) {
+            return in.getBinary(ordinal);
+        }
+    }
+
+    /** {@link BinaryWriter} for {@link ArrayData} input. */
+    public static final class BinaryWriterForArray extends BinaryWriter<ArrayData> {
+
+        private BinaryWriterForArray(FixedSizeBinaryVector fixedSizeBinaryVector) {
+            super(fixedSizeBinaryVector);
+        }
+
+        @Override
+        boolean isNullAt(ArrayData in, int ordinal) {
+            return in.isNullAt(ordinal);
+        }
+
+        @Override
+        byte[] readBinary(ArrayData in, int ordinal) {
+            return in.getBinary(ordinal);
+        }
+    }
+}


[flink] 03/03: [FLINK-30607][python] Support NoneType for Pandas UDF

Posted by di...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit afc8fb08ab0879537814d3c77372268eb6d6a4de
Author: Dian Fu <di...@apache.org>
AuthorDate: Tue Jan 10 17:02:11 2023 +0800

    [FLINK-30607][python] Support NoneType for Pandas UDF
    
    This closes #21639.
---
 flink-python/pyflink/fn_execution/coders.py        |   4 +-
 .../pyflink/fn_execution/flink_fn_execution_pb2.py | 122 ++++++++++-----------
 .../pyflink/proto/flink-fn-execution.proto         |   1 +
 .../pyflink/table/tests/test_pandas_udf.py         |   2 +-
 flink-python/pyflink/table/types.py                |   4 +
 .../flink/table/runtime/arrow/ArrowUtils.java      |  15 +++
 .../arrow/vectors/ArrowNullColumnVector.java       |  36 ++++++
 .../table/runtime/arrow/writers/NullWriter.java    |  35 ++++++
 8 files changed, 156 insertions(+), 63 deletions(-)
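
On the Python side, this commit boils down to a new NULL <-> Arrow mapping. A small
sketch of the round trip added to types.py, assuming pyarrow is installed (it is a
dependency of the flink-python Arrow path):

    import pyarrow as pa
    from pyflink.table.types import NullType, from_arrow_type, to_arrow_type

    assert to_arrow_type(NullType()) == pa.null()
    assert isinstance(from_arrow_type(pa.null()), NullType)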

diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py
index 190efb5dea6..6cdd6bfaa45 100644
--- a/flink-python/pyflink/fn_execution/coders.py
+++ b/flink-python/pyflink/fn_execution/coders.py
@@ -30,7 +30,7 @@ from pyflink.common.typeinfo import TypeInformation, BasicTypeInfo, BasicType, D
 from pyflink.table.types import TinyIntType, SmallIntType, IntType, BigIntType, BooleanType, \
     FloatType, DoubleType, VarCharType, VarBinaryType, DecimalType, DateType, TimeType, \
     LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType, MapType, \
-    BinaryType
+    BinaryType, NullType
 
 try:
     from pyflink.fn_execution import coder_impl_fast as coder_impl
@@ -155,6 +155,8 @@ class LengthPrefixBaseCoder(ABC):
             return MapType(cls._to_data_type(field_type.map_info.key_type),
                            cls._to_data_type(field_type.map_info.value_type),
                            field_type.nullable)
+        elif field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.NULL:
+            return NullType()
         else:
             raise ValueError("field_type %s is not supported." % field_type)
 
diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
index db4a2dc11c9..95ca119137d 100644
--- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
+++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
@@ -31,7 +31,7 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\n\x05input\"\xa8\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12\x37\n\x06inputs\x18\x02 \x03(\x0b\x32\'.org.ap [...]
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\n\x05input\"\xa8\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12\x37\n\x06inputs\x18\x02 \x03(\x0b\x32\'.org.ap [...]
 
 
 
@@ -462,7 +462,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _USERDEFINEDAGGREGATEFUNCTIONS._serialized_start=2270
   _USERDEFINEDAGGREGATEFUNCTIONS._serialized_end=2804
   _SCHEMA._serialized_start=2807
-  _SCHEMA._serialized_end=4835
+  _SCHEMA._serialized_end=4845
   _SCHEMA_MAPINFO._serialized_start=2882
   _SCHEMA_MAPINFO._serialized_end=3033
   _SCHEMA_TIMEINFO._serialized_start=3035
@@ -488,63 +488,63 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _SCHEMA_FIELD._serialized_start=4435
   _SCHEMA_FIELD._serialized_end=4543
   _SCHEMA_TYPENAME._serialized_start=4546
-  _SCHEMA_TYPENAME._serialized_end=4835
-  _TYPEINFO._serialized_start=4838
-  _TYPEINFO._serialized_end=6185
-  _TYPEINFO_MAPTYPEINFO._serialized_start=5332
-  _TYPEINFO_MAPTYPEINFO._serialized_end=5471
-  _TYPEINFO_ROWTYPEINFO._serialized_start=5474
-  _TYPEINFO_ROWTYPEINFO._serialized_end=5658
-  _TYPEINFO_ROWTYPEINFO_FIELD._serialized_start=5567
-  _TYPEINFO_ROWTYPEINFO_FIELD._serialized_end=5658
-  _TYPEINFO_TUPLETYPEINFO._serialized_start=5660
-  _TYPEINFO_TUPLETYPEINFO._serialized_end=5740
-  _TYPEINFO_AVROTYPEINFO._serialized_start=5742
-  _TYPEINFO_AVROTYPEINFO._serialized_end=5772
-  _TYPEINFO_TYPENAME._serialized_start=5775
-  _TYPEINFO_TYPENAME._serialized_end=6172
-  _USERDEFINEDDATASTREAMFUNCTION._serialized_start=6188
-  _USERDEFINEDDATASTREAMFUNCTION._serialized_end=7239
-  _USERDEFINEDDATASTREAMFUNCTION_JOBPARAMETER._serialized_start=6682
-  _USERDEFINEDDATASTREAMFUNCTION_JOBPARAMETER._serialized_end=6724
-  _USERDEFINEDDATASTREAMFUNCTION_RUNTIMECONTEXT._serialized_start=6727
-  _USERDEFINEDDATASTREAMFUNCTION_RUNTIMECONTEXT._serialized_end=7063
-  _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE._serialized_start=7066
-  _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE._serialized_end=7239
-  _STATEDESCRIPTOR._serialized_start=7242
-  _STATEDESCRIPTOR._serialized_end=9134
-  _STATEDESCRIPTOR_STATETTLCONFIG._serialized_start=7374
-  _STATEDESCRIPTOR_STATETTLCONFIG._serialized_end=9134
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES._serialized_start=7845
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES._serialized_end=8943
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY._serialized_start=8023
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY._serialized_end=8111
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY._serialized_start=8113
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY._serialized_end=8188
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY._serialized_start=8191
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY._serialized_end=8799
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES._serialized_start=8801
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES._serialized_end=8899
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY._serialized_start=8901
-  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY._serialized_end=8943
-  _STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE._serialized_start=8945
-  _STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE._serialized_end=9013
-  _STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY._serialized_start=9015
-  _STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY._serialized_end=9089
-  _STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC._serialized_start=9091
-  _STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC._serialized_end=9134
-  _CODERINFODESCRIPTOR._serialized_start=9137
-  _CODERINFODESCRIPTOR._serialized_end=10146
-  _CODERINFODESCRIPTOR_FLATTENROWTYPE._serialized_start=9730
-  _CODERINFODESCRIPTOR_FLATTENROWTYPE._serialized_end=9804
-  _CODERINFODESCRIPTOR_ROWTYPE._serialized_start=9806
-  _CODERINFODESCRIPTOR_ROWTYPE._serialized_end=9873
-  _CODERINFODESCRIPTOR_ARROWTYPE._serialized_start=9875
-  _CODERINFODESCRIPTOR_ARROWTYPE._serialized_end=9944
-  _CODERINFODESCRIPTOR_OVERWINDOWARROWTYPE._serialized_start=9946
-  _CODERINFODESCRIPTOR_OVERWINDOWARROWTYPE._serialized_end=10025
-  _CODERINFODESCRIPTOR_RAWTYPE._serialized_start=10027
-  _CODERINFODESCRIPTOR_RAWTYPE._serialized_end=10099
-  _CODERINFODESCRIPTOR_MODE._serialized_start=10101
-  _CODERINFODESCRIPTOR_MODE._serialized_end=10133
+  _SCHEMA_TYPENAME._serialized_end=4845
+  _TYPEINFO._serialized_start=4848
+  _TYPEINFO._serialized_end=6195
+  _TYPEINFO_MAPTYPEINFO._serialized_start=5342
+  _TYPEINFO_MAPTYPEINFO._serialized_end=5481
+  _TYPEINFO_ROWTYPEINFO._serialized_start=5484
+  _TYPEINFO_ROWTYPEINFO._serialized_end=5668
+  _TYPEINFO_ROWTYPEINFO_FIELD._serialized_start=5577
+  _TYPEINFO_ROWTYPEINFO_FIELD._serialized_end=5668
+  _TYPEINFO_TUPLETYPEINFO._serialized_start=5670
+  _TYPEINFO_TUPLETYPEINFO._serialized_end=5750
+  _TYPEINFO_AVROTYPEINFO._serialized_start=5752
+  _TYPEINFO_AVROTYPEINFO._serialized_end=5782
+  _TYPEINFO_TYPENAME._serialized_start=5785
+  _TYPEINFO_TYPENAME._serialized_end=6182
+  _USERDEFINEDDATASTREAMFUNCTION._serialized_start=6198
+  _USERDEFINEDDATASTREAMFUNCTION._serialized_end=7249
+  _USERDEFINEDDATASTREAMFUNCTION_JOBPARAMETER._serialized_start=6692
+  _USERDEFINEDDATASTREAMFUNCTION_JOBPARAMETER._serialized_end=6734
+  _USERDEFINEDDATASTREAMFUNCTION_RUNTIMECONTEXT._serialized_start=6737
+  _USERDEFINEDDATASTREAMFUNCTION_RUNTIMECONTEXT._serialized_end=7073
+  _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE._serialized_start=7076
+  _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE._serialized_end=7249
+  _STATEDESCRIPTOR._serialized_start=7252
+  _STATEDESCRIPTOR._serialized_end=9144
+  _STATEDESCRIPTOR_STATETTLCONFIG._serialized_start=7384
+  _STATEDESCRIPTOR_STATETTLCONFIG._serialized_end=9144
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES._serialized_start=7855
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES._serialized_end=8953
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY._serialized_start=8033
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY._serialized_end=8121
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY._serialized_start=8123
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY._serialized_end=8198
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY._serialized_start=8201
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY._serialized_end=8809
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES._serialized_start=8811
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES._serialized_end=8909
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY._serialized_start=8911
+  _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY._serialized_end=8953
+  _STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE._serialized_start=8955
+  _STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE._serialized_end=9023
+  _STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY._serialized_start=9025
+  _STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY._serialized_end=9099
+  _STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC._serialized_start=9101
+  _STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC._serialized_end=9144
+  _CODERINFODESCRIPTOR._serialized_start=9147
+  _CODERINFODESCRIPTOR._serialized_end=10156
+  _CODERINFODESCRIPTOR_FLATTENROWTYPE._serialized_start=9740
+  _CODERINFODESCRIPTOR_FLATTENROWTYPE._serialized_end=9814
+  _CODERINFODESCRIPTOR_ROWTYPE._serialized_start=9816
+  _CODERINFODESCRIPTOR_ROWTYPE._serialized_end=9883
+  _CODERINFODESCRIPTOR_ARROWTYPE._serialized_start=9885
+  _CODERINFODESCRIPTOR_ARROWTYPE._serialized_end=9954
+  _CODERINFODESCRIPTOR_OVERWINDOWARROWTYPE._serialized_start=9956
+  _CODERINFODESCRIPTOR_OVERWINDOWARROWTYPE._serialized_end=10035
+  _CODERINFODESCRIPTOR_RAWTYPE._serialized_start=10037
+  _CODERINFODESCRIPTOR_RAWTYPE._serialized_end=10109
+  _CODERINFODESCRIPTOR_MODE._serialized_start=10111
+  _CODERINFODESCRIPTOR_MODE._serialized_end=10143
 # @@protoc_insertion_point(module_scope)
diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto
index 613a7cd888b..db27aed7409 100644
--- a/flink-python/pyflink/proto/flink-fn-execution.proto
+++ b/flink-python/pyflink/proto/flink-fn-execution.proto
@@ -208,6 +208,7 @@ message Schema {
     MULTISET = 18;
     LOCAL_ZONED_TIMESTAMP = 19;
     ZONED_TIMESTAMP = 20;
+    NULL = 21;
   }
 
   message MapInfo {
diff --git a/flink-python/pyflink/table/tests/test_pandas_udf.py b/flink-python/pyflink/table/tests/test_pandas_udf.py
index 0e4d9d9a8ab..7a6b36f5b17 100644
--- a/flink-python/pyflink/table/tests/test_pandas_udf.py
+++ b/flink-python/pyflink/table/tests/test_pandas_udf.py
@@ -313,7 +313,7 @@ class PandasUDFITTests(object):
             nested_array_func(t.t),
             row_func(t.u),
             map_func(t.v),
-            binary_func(t.k)) \
+            binary_func(t.w)) \
             .execute_insert("Results_test_all_data_types").wait()
         actual = source_sink_utils.results()
         self.assert_equals(
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index c8227f66042..856942e5936 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -2256,6 +2256,8 @@ def from_arrow_type(arrow_type, nullable: bool = True) -> DataType:
                             str(arrow_type))
         return RowType([RowField(field.name, from_arrow_type(field.type, field.nullable))
                         for field in arrow_type])
+    elif types.is_null(arrow_type):
+        return NullType()
     else:
         raise TypeError("Unsupported data type to convert to Arrow type: " + str(dt))
 
@@ -2322,6 +2324,8 @@ def to_arrow_type(data_type: DataType):
         fields = [pa.field(field.name, to_arrow_type(field.data_type), field.data_type._nullable)
                   for field in data_type]
         return pa.struct(fields)
+    elif isinstance(data_type, NullType):
+        return pa.null()
     else:
         raise ValueError("field_type %s is not supported." % data_type)
 
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
index bdd691b5c9d..d14a3952607 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
@@ -43,6 +43,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowMapColumnVector;
+import org.apache.flink.table.runtime.arrow.vectors.ArrowNullColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowRowColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowTimeColumnVector;
@@ -61,6 +62,7 @@ import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
 import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
 import org.apache.flink.table.runtime.arrow.writers.IntWriter;
 import org.apache.flink.table.runtime.arrow.writers.MapWriter;
+import org.apache.flink.table.runtime.arrow.writers.NullWriter;
 import org.apache.flink.table.runtime.arrow.writers.RowWriter;
 import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
 import org.apache.flink.table.runtime.arrow.writers.TimeWriter;
@@ -83,6 +85,7 @@ import org.apache.flink.table.types.logical.LegacyTypeInformationType;
 import org.apache.flink.table.types.logical.LocalZonedTimestampType;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.MapType;
+import org.apache.flink.table.types.logical.NullType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.SmallIntType;
 import org.apache.flink.table.types.logical.TimeType;
@@ -110,6 +113,7 @@ import org.apache.arrow.vector.FixedSizeBinaryVector;
 import org.apache.arrow.vector.Float4Vector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.NullVector;
 import org.apache.arrow.vector.SmallIntVector;
 import org.apache.arrow.vector.TimeMicroVector;
 import org.apache.arrow.vector.TimeMilliVector;
@@ -310,6 +314,8 @@ public final class ArrowUtils {
                                 ((StructVector) vector).getVectorById(i), rowType.getTypeAt(i));
             }
             return RowWriter.forRow((StructVector) vector, fieldsWriters);
+        } else if (vector instanceof NullVector) {
+            return new NullWriter<>((NullVector) vector);
         } else {
             throw new UnsupportedOperationException(
                     String.format("Unsupported type %s.", fieldType));
@@ -385,6 +391,8 @@ public final class ArrowUtils {
                                 ((StructVector) vector).getVectorById(i), rowType.getTypeAt(i));
             }
             return RowWriter.forArray((StructVector) vector, fieldsWriters);
+        } else if (vector instanceof NullVector) {
+            return new NullWriter<>((NullVector) vector);
         } else {
             throw new UnsupportedOperationException(
                     String.format("Unsupported type %s.", fieldType));
@@ -459,6 +467,8 @@ public final class ArrowUtils {
                                 structVector.getVectorById(i), ((RowType) fieldType).getTypeAt(i));
             }
             return new ArrowRowColumnVector(structVector, fieldColumns);
+        } else if (vector instanceof NullVector) {
+            return ArrowNullColumnVector.INSTANCE;
         } else {
             throw new UnsupportedOperationException(
                     String.format("Unsupported type %s.", fieldType));
@@ -801,6 +811,11 @@ public final class ArrowUtils {
             return new ArrowType.Map(false);
         }
 
+        @Override
+        public ArrowType visit(NullType nullType) {
+            return ArrowType.Null.INSTANCE;
+        }
+
         @Override
         protected ArrowType defaultMethod(LogicalType logicalType) {
             if (logicalType instanceof LegacyTypeInformationType) {
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowNullColumnVector.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowNullColumnVector.java
new file mode 100644
index 00000000000..d01c89c423a
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowNullColumnVector.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.vectors;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.columnar.vector.ColumnVector;
+
+/** Arrow column vector for Null. */
+@Internal
+public final class ArrowNullColumnVector implements ColumnVector {
+
+    public static final ArrowNullColumnVector INSTANCE = new ArrowNullColumnVector();
+
+    private ArrowNullColumnVector() {}
+
+    @Override
+    public boolean isNullAt(int i) {
+        return true;
+    }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/NullWriter.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/NullWriter.java
new file mode 100644
index 00000000000..76f6d836b89
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/NullWriter.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.writers;
+
+import org.apache.flink.annotation.Internal;
+
+import org.apache.arrow.vector.NullVector;
+
+/** {@link ArrowFieldWriter} for Null. */
+@Internal
+public class NullWriter<T> extends ArrowFieldWriter<T> {
+
+    public NullWriter(NullVector nullVector) {
+        super(nullVector);
+    }
+
+    @Override
+    public void doWrite(T in, int ordinal) {}
+}


[flink] 01/03: [FLINK-30607][python] Support MapType for Pandas UDF

Posted by di...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b781a13dd615e8d131defe37ca9e550416c10595
Author: Dian Fu <di...@apache.org>
AuthorDate: Tue Jan 10 15:29:55 2023 +0800

    [FLINK-30607][python] Support MapType for Pandas UDF
    
    This closes #21639.
---
 flink-python/pyflink/fn_execution/coders.py        |   6 +-
 .../pyflink/table/tests/test_pandas_udf.py         |  48 ++++++--
 flink-python/pyflink/table/types.py                |   6 +
 .../flink/table/runtime/arrow/ArrowUtils.java      |  54 +++++++++
 .../arrow/vectors/ArrowMapColumnVector.java        |  60 ++++++++++
 .../table/runtime/arrow/writers/MapWriter.java     | 130 +++++++++++++++++++++
 6 files changed, 294 insertions(+), 10 deletions(-)
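
The test hunk in this commit shows the user-facing shape of the feature. As a
standalone illustration (table wiring omitted), a minimal sketch of a Pandas UDF over
a MAP column, mirroring map_func in the test below; note the key type is declared
non-nullable, which the Arrow map type requires (enforced in the ArrowUtils hunk):

    import pandas as pd
    from pyflink.table import DataTypes
    from pyflink.table.udf import udf

    map_type = DataTypes.MAP(DataTypes.STRING(False), DataTypes.STRING())

    @udf(result_type=map_type, func_type="pandas")
    def map_func(map_param):
        # Each element of the Series is a dict, e.g. {"1": "hello", "2": "world"}.
        assert isinstance(map_param, pd.Series)
        return map_param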

diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py
index 1527adb4e39..704d278666b 100644
--- a/flink-python/pyflink/fn_execution/coders.py
+++ b/flink-python/pyflink/fn_execution/coders.py
@@ -29,7 +29,7 @@ from pyflink.common.typeinfo import TypeInformation, BasicTypeInfo, BasicType, D
     ExternalTypeInfo
 from pyflink.table.types import TinyIntType, SmallIntType, IntType, BigIntType, BooleanType, \
     FloatType, DoubleType, VarCharType, VarBinaryType, DecimalType, DateType, TimeType, \
-    LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType
+    LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType, MapType
 
 try:
     from pyflink.fn_execution import coder_impl_fast as coder_impl
@@ -148,6 +148,10 @@ class LengthPrefixBaseCoder(ABC):
             return RowType(
                 [RowField(f.name, cls._to_data_type(f.type), f.description)
                  for f in field_type.row_schema.fields], field_type.nullable)
+        elif field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.MAP:
+            return MapType(cls._to_data_type(field_type.map_info.key_type),
+                           cls._to_data_type(field_type.map_info.value_type),
+                           field_type.nullable)
         else:
             raise ValueError("field_type %s is not supported." % field_type)
 
diff --git a/flink-python/pyflink/table/tests/test_pandas_udf.py b/flink-python/pyflink/table/tests/test_pandas_udf.py
index aa0dd8a9596..46d9560b351 100644
--- a/flink-python/pyflink/table/tests/test_pandas_udf.py
+++ b/flink-python/pyflink/table/tests/test_pandas_udf.py
@@ -213,12 +213,38 @@ class PandasUDFITTests(object):
                 'row_param.f4 of wrong type %s !' % type(row_param.f4[0])
             return row_param
 
+        map_type = DataTypes.MAP(DataTypes.STRING(False), DataTypes.STRING())
+
+        @udf(result_type=map_type, func_type="pandas")
+        def map_func(map_param):
+            assert isinstance(map_param, pd.Series)
+            return map_param
+
         sink_table_ddl = """
-        CREATE TABLE Results_test_all_data_types(
-        a TINYINT, b SMALLINT, c INT, d BIGINT, e BOOLEAN, f BOOLEAN, g FLOAT, h DOUBLE, i STRING,
-        j StRING, k BYTES, l DECIMAL(38, 18), m DECIMAL(38, 18), n DATE, o TIME, p TIMESTAMP(3),
-        q ARRAY<STRING>, r ARRAY<TIMESTAMP(3)>, s ARRAY<INT>, t ARRAY<STRING>,
-        u ROW<f1 INT, f2 STRING, f3 TIMESTAMP(3), f4 ARRAY<INT>>) WITH ('connector'='test-sink')
+            CREATE TABLE Results_test_all_data_types(
+                a TINYINT,
+                b SMALLINT,
+                c INT,
+                d BIGINT,
+                e BOOLEAN,
+                f BOOLEAN,
+                g FLOAT,
+                h DOUBLE,
+                i STRING,
+                j STRING,
+                k BYTES,
+                l DECIMAL(38, 18),
+                m DECIMAL(38, 18),
+                n DATE,
+                o TIME,
+                p TIMESTAMP(3),
+                q ARRAY<STRING>,
+                r ARRAY<TIMESTAMP(3)>,
+                s ARRAY<INT>,
+                t ARRAY<STRING>,
+                u ROW<f1 INT, f2 STRING, f3 TIMESTAMP(3), f4 ARRAY<INT>>,
+                v MAP<STRING, STRING>
+            ) WITH ('connector'='test-sink')
         """
         self.t_env.execute_sql(sink_table_ddl)
 
@@ -228,7 +254,8 @@ class PandasUDFITTests(object):
               decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
               datetime.date(2014, 9, 13), datetime.time(hour=1, minute=0, second=1),
               timestamp_value, ['hello', '中文', None], [timestamp_value], [1, 2],
-              [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]))],
+              [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]),
+              {"1": "hello", "2": "world"})],
             DataTypes.ROW(
                 [DataTypes.FIELD("a", DataTypes.TINYINT()),
                  DataTypes.FIELD("b", DataTypes.SMALLINT()),
@@ -250,7 +277,8 @@ class PandasUDFITTests(object):
                  DataTypes.FIELD("r", DataTypes.ARRAY(DataTypes.TIMESTAMP(3))),
                  DataTypes.FIELD("s", DataTypes.ARRAY(DataTypes.INT())),
                  DataTypes.FIELD("t", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))),
-                 DataTypes.FIELD("u", row_type)]))
+                 DataTypes.FIELD("u", row_type),
+                 DataTypes.FIELD("v", map_type)]))
 
         t.select(
             tinyint_func(t.a),
@@ -273,7 +301,8 @@ class PandasUDFITTests(object):
             array_timestamp_func(t.r),
             array_int_func(t.s),
             nested_array_func(t.t),
-            row_func(t.u)) \
+            row_func(t.u),
+            map_func(t.v)) \
             .execute_insert("Results_test_all_data_types").wait()
         actual = source_sink_utils.results()
         self.assert_equals(
@@ -282,7 +311,8 @@ class PandasUDFITTests(object):
              "[102, 108, 105, 110, 107], 1000000000000000000.050000000000000000, "
              "1000000000000000000.059999999999999999, 2014-09-13, 01:00:01, "
              "1970-01-02T00:00:00.123, [hello, 中文, null], [1970-01-02T00:00:00.123], "
-             "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02T00:00:00.123, [1, 2]]]"])
+             "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02T00:00:00.123, [1, 2]], "
+             "{1=hello, 2=world}]"])
 
     def test_invalid_pandas_udf(self):
 
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index b3aac1d3525..83d54629dad 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -2244,6 +2244,10 @@ def from_arrow_type(arrow_type, nullable: bool = True) -> DataType:
             return TimestampType(6, nullable)
         else:
             return TimestampType(9, nullable)
+    elif types.is_map(arrow_type):
+        return MapType(from_arrow_type(arrow_type.key_type),
+                       from_arrow_type(arrow_type.item_type),
+                       nullable)
     elif types.is_list(arrow_type):
         return ArrayType(from_arrow_type(arrow_type.value_type), nullable)
     elif types.is_struct(arrow_type):
@@ -2301,6 +2305,8 @@ def to_arrow_type(data_type: DataType):
             return pa.timestamp('us')
         else:
             return pa.timestamp('ns')
+    elif isinstance(data_type, MapType):
+        return pa.map_(to_arrow_type(data_type.key_type), to_arrow_type(data_type.value_type))
     elif isinstance(data_type, ArrayType):
         if type(data_type.element_type) in [LocalZonedTimestampType, RowType]:
             raise ValueError("%s is not supported to be used as the element type of ArrayType." %
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
index 28f0a74e51d..2d4e09e4553 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
@@ -41,6 +41,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowDecimalColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
+import org.apache.flink.table.runtime.arrow.vectors.ArrowMapColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowRowColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowTimeColumnVector;
@@ -57,6 +58,7 @@ import org.apache.flink.table.runtime.arrow.writers.DecimalWriter;
 import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
 import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
 import org.apache.flink.table.runtime.arrow.writers.IntWriter;
+import org.apache.flink.table.runtime.arrow.writers.MapWriter;
 import org.apache.flink.table.runtime.arrow.writers.RowWriter;
 import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
 import org.apache.flink.table.runtime.arrow.writers.TimeWriter;
@@ -77,6 +79,7 @@ import org.apache.flink.table.types.logical.IntType;
 import org.apache.flink.table.types.logical.LegacyTypeInformationType;
 import org.apache.flink.table.types.logical.LocalZonedTimestampType;
 import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.SmallIntType;
 import org.apache.flink.table.types.logical.TimeType;
@@ -88,6 +91,7 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeDefaultVisitor;
 import org.apache.flink.table.types.utils.TypeConversions;
 import org.apache.flink.types.Row;
 import org.apache.flink.types.RowKind;
+import org.apache.flink.util.Preconditions;
 
 import org.apache.flink.shaded.guava30.com.google.common.collect.LinkedHashMultiset;
 
@@ -114,6 +118,7 @@ import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
 import org.apache.arrow.vector.complex.StructVector;
 import org.apache.arrow.vector.ipc.ArrowStreamWriter;
 import org.apache.arrow.vector.ipc.ReadChannel;
@@ -139,6 +144,7 @@ import java.nio.ByteBuffer;
 import java.nio.channels.Channels;
 import java.nio.channels.ReadableByteChannel;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
@@ -200,6 +206,18 @@ public final class ArrowUtils {
             for (RowType.RowField field : rowType.getFields()) {
                 children.add(toArrowField(field.getName(), field.getType()));
             }
+        } else if (logicalType instanceof MapType) {
+            MapType mapType = (MapType) logicalType;
+            Preconditions.checkArgument(
+                    !mapType.getKeyType().isNullable(), "Map key type should be non-nullable");
+            children =
+                    Collections.singletonList(
+                            new Field(
+                                    "items",
+                                    new FieldType(false, ArrowType.Struct.INSTANCE, null),
+                                    Arrays.asList(
+                                            toArrowField("key", mapType.getKeyType()),
+                                            toArrowField("value", mapType.getValueType()))));
         }
         return new Field(fieldName, fieldType, children);
     }
@@ -259,6 +277,17 @@ public final class ArrowUtils {
                 precision = ((TimestampType) fieldType).getPrecision();
             }
             return TimestampWriter.forRow(vector, precision);
+        } else if (vector instanceof MapVector) {
+            MapVector mapVector = (MapVector) vector;
+            LogicalType keyType = ((MapType) fieldType).getKeyType();
+            LogicalType valueType = ((MapType) fieldType).getValueType();
+            StructVector structVector = (StructVector) mapVector.getDataVector();
+            return MapWriter.forRow(
+                    mapVector,
+                    createArrowFieldWriterForArray(
+                            structVector.getChild(MapVector.KEY_NAME), keyType),
+                    createArrowFieldWriterForArray(
+                            structVector.getChild(MapVector.VALUE_NAME), valueType));
         } else if (vector instanceof ListVector) {
             ListVector listVector = (ListVector) vector;
             LogicalType elementType = ((ArrayType) fieldType).getElementType();
@@ -321,6 +350,17 @@ public final class ArrowUtils {
                 precision = ((TimestampType) fieldType).getPrecision();
             }
             return TimestampWriter.forArray(vector, precision);
+        } else if (vector instanceof MapVector) {
+            MapVector mapVector = (MapVector) vector;
+            LogicalType keyType = ((MapType) fieldType).getKeyType();
+            LogicalType valueType = ((MapType) fieldType).getValueType();
+            StructVector structVector = (StructVector) mapVector.getDataVector();
+            return MapWriter.forArray(
+                    mapVector,
+                    createArrowFieldWriterForArray(
+                            structVector.getChild(MapVector.KEY_NAME), keyType),
+                    createArrowFieldWriterForArray(
+                            structVector.getChild(MapVector.VALUE_NAME), valueType));
         } else if (vector instanceof ListVector) {
             ListVector listVector = (ListVector) vector;
             LogicalType elementType = ((ArrayType) fieldType).getElementType();
@@ -385,6 +425,15 @@ public final class ArrowUtils {
         } else if (vector instanceof TimeStampVector
                 && ((ArrowType.Timestamp) vector.getField().getType()).getTimezone() == null) {
             return new ArrowTimestampColumnVector(vector);
+        } else if (vector instanceof MapVector) {
+            MapVector mapVector = (MapVector) vector;
+            LogicalType keyType = ((MapType) fieldType).getKeyType();
+            LogicalType valueType = ((MapType) fieldType).getValueType();
+            StructVector structVector = (StructVector) mapVector.getDataVector();
+            return new ArrowMapColumnVector(
+                    mapVector,
+                    createColumnVector(structVector.getChild(MapVector.KEY_NAME), keyType),
+                    createColumnVector(structVector.getChild(MapVector.VALUE_NAME), valueType));
         } else if (vector instanceof ListVector) {
             ListVector listVector = (ListVector) vector;
             return new ArrowArrayColumnVector(
@@ -732,6 +781,11 @@ public final class ArrowUtils {
             return ArrowType.Struct.INSTANCE;
         }
 
+        @Override
+        public ArrowType visit(MapType mapType) {
+            return new ArrowType.Map(false);
+        }
+
         @Override
         protected ArrowType defaultMethod(LogicalType logicalType) {
             if (logicalType instanceof LegacyTypeInformationType) {
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java
new file mode 100644
index 00000000000..d269c7e28a3
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.vectors;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.columnar.ColumnarMapData;
+import org.apache.flink.table.data.columnar.vector.ColumnVector;
+import org.apache.flink.table.data.columnar.vector.MapColumnVector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.vector.complex.MapVector;
+
+/** Arrow column vector for Map. */
+@Internal
+public final class ArrowMapColumnVector implements MapColumnVector {
+
+    /** Container which is used to store the map values of a column to read. */
+    private final MapVector mapVector;
+
+    private final ColumnVector keyVector;
+
+    private final ColumnVector valueVector;
+
+    public ArrowMapColumnVector(
+            MapVector mapVector, ColumnVector keyVector, ColumnVector valueVector) {
+        this.mapVector = Preconditions.checkNotNull(mapVector);
+        this.keyVector = Preconditions.checkNotNull(keyVector);
+        this.valueVector = Preconditions.checkNotNull(valueVector);
+    }
+
+    @Override
+    public MapData getMap(int i) {
+        int index = i * MapVector.OFFSET_WIDTH;
+        int offset = mapVector.getOffsetBuffer().getInt(index);
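+        // getInnerValueCountAt(i) is offsets[i + 1] - offsets[i], i.e. the
+        // number of entries belonging to row i in the flattened key/value data.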
+        int numElements = mapVector.getInnerValueCountAt(i);
+        return new ColumnarMapData(keyVector, valueVector, offset, numElements);
+    }
+
+    @Override
+    public boolean isNullAt(int i) {
+        return mapVector.isNull(i);
+    }
+}
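
The (offset, numElements) pair passed to ColumnarMapData above selects a
contiguous slice of the flattened key/value child vectors. A small
self-contained sketch of those semantics using Flink's heap column vectors
(illustrative only; HeapIntVector comes from flink-table-common, outside this
patch):

    import org.apache.flink.table.data.columnar.ColumnarMapData;
    import org.apache.flink.table.data.columnar.vector.heap.HeapIntVector;

    public final class ColumnarMapSketch {
        public static void main(String[] args) {
            // Flattened entries of two maps: row 0 = {1: 10}, row 1 = {2: 20, 3: 30}.
            HeapIntVector keys = new HeapIntVector(3);
            HeapIntVector values = new HeapIntVector(3);
            for (int i = 0; i < 3; i++) {
                keys.setInt(i, i + 1);
                values.setInt(i, (i + 1) * 10);
            }
            // Row 1 starts at offset 1 and spans two entries.
            ColumnarMapData row1 = new ColumnarMapData(keys, values, 1, 2);
            System.out.println(row1.size());                 // 2
            System.out.println(row1.keyArray().getInt(0));   // 2
            System.out.println(row1.valueArray().getInt(1)); // 30
        }
    }
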
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java
new file mode 100644
index 00000000000..c2c960c3328
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.writers;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+
+/** {@link ArrowFieldWriter} for Map. */
+@Internal
+public abstract class MapWriter<T> extends ArrowFieldWriter<T> {
+
+    public static MapWriter<RowData> forRow(
+            MapVector mapVector,
+            ArrowFieldWriter<ArrayData> keyWriter,
+            ArrowFieldWriter<ArrayData> valueWriter) {
+        return new MapWriterForRow(mapVector, keyWriter, valueWriter);
+    }
+
+    public static MapWriter<ArrayData> forArray(
+            MapVector mapVector,
+            ArrowFieldWriter<ArrayData> keyWriter,
+            ArrowFieldWriter<ArrayData> valueWriter) {
+        return new MapWriterForArray(mapVector, keyWriter, valueWriter);
+    }
+
+    // ------------------------------------------------------------------------------------------
+
+    private final ArrowFieldWriter<ArrayData> keyWriter;
+
+    private final ArrowFieldWriter<ArrayData> valueWriter;
+
+    private MapWriter(
+            MapVector mapVector,
+            ArrowFieldWriter<ArrayData> keyWriter,
+            ArrowFieldWriter<ArrayData> valueWriter) {
+        super(mapVector);
+        this.keyWriter = Preconditions.checkNotNull(keyWriter);
+        this.valueWriter = Preconditions.checkNotNull(valueWriter);
+    }
+
+    abstract boolean isNullAt(T in, int ordinal);
+
+    abstract MapData readMap(T in, int ordinal);
+
+    @Override
+    public void doWrite(T in, int ordinal) {
+        if (!isNullAt(in, ordinal)) {
+            ((MapVector) getValueVector()).startNewValue(getCount());
+
+            StructVector structVector =
+                    (StructVector) ((MapVector) getValueVector()).getDataVector();
+            MapData map = readMap(in, ordinal);
+            ArrayData keys = map.keyArray();
+            ArrayData values = map.valueArray();
+            for (int i = 0; i < map.size(); i++) {
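+                // Every entry occupies one slot in the inner struct vector; mark
+                // the slot non-null before delegating to the key/value writers.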
+                structVector.setIndexDefined(keyWriter.getCount());
+                keyWriter.write(keys, i);
+                valueWriter.write(values, i);
+            }
+
+            ((MapVector) getValueVector()).endValue(getCount(), map.size());
+        }
+    }
+
+    // ------------------------------------------------------------------------------------------
+
+    /** {@link MapWriter} for {@link RowData} input. */
+    public static final class MapWriterForRow extends MapWriter<RowData> {
+
+        private MapWriterForRow(
+                MapVector mapVector,
+                ArrowFieldWriter<ArrayData> keyWriter,
+                ArrowFieldWriter<ArrayData> valueWriter) {
+            super(mapVector, keyWriter, valueWriter);
+        }
+
+        @Override
+        boolean isNullAt(RowData in, int ordinal) {
+            return in.isNullAt(ordinal);
+        }
+
+        @Override
+        MapData readMap(RowData in, int ordinal) {
+            return in.getMap(ordinal);
+        }
+    }
+
+    /** {@link MapWriter} for {@link ArrayData} input. */
+    public static final class MapWriterForArray extends MapWriter<ArrayData> {
+
+        private MapWriterForArray(
+                MapVector mapVector,
+                ArrowFieldWriter<ArrayData> keyWriter,
+                ArrowFieldWriter<ArrayData> valueWriter) {
+            super(mapVector, keyWriter, valueWriter);
+        }
+
+        @Override
+        boolean isNullAt(ArrayData in, int ordinal) {
+            return in.isNullAt(ordinal);
+        }
+
+        @Override
+        MapData readMap(ArrayData in, int ordinal) {
+            return in.getMap(ordinal);
+        }
+    }
+}
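
Putting the pieces together, a hedged end-to-end sketch of the Java-side
write/read round trip for a MAP column. It is not part of this commit and
assumes the module's pre-existing ArrowUtils.createRowDataArrowWriter and
ArrowUtils.createArrowReader factories, which the other Arrow types already
go through:

    import org.apache.flink.table.data.GenericMapData;
    import org.apache.flink.table.data.GenericRowData;
    import org.apache.flink.table.data.MapData;
    import org.apache.flink.table.data.RowData;
    import org.apache.flink.table.data.StringData;
    import org.apache.flink.table.runtime.arrow.ArrowReader;
    import org.apache.flink.table.runtime.arrow.ArrowUtils;
    import org.apache.flink.table.runtime.arrow.ArrowWriter;
    import org.apache.flink.table.types.logical.IntType;
    import org.apache.flink.table.types.logical.MapType;
    import org.apache.flink.table.types.logical.RowType;
    import org.apache.flink.table.types.logical.VarCharType;

    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.VectorSchemaRoot;

    import java.util.Collections;

    public final class MapRoundTripSketch {
        public static void main(String[] args) {
            RowType rowType =
                    RowType.of(
                            new MapType(
                                    new VarCharType(VarCharType.MAX_LENGTH),
                                    new IntType()));
            try (RootAllocator allocator = new RootAllocator();
                    VectorSchemaRoot root =
                            VectorSchemaRoot.create(
                                    ArrowUtils.toArrowSchema(rowType), allocator)) {
                // Write a single row whose only field is the map {"a": 1};
                // the MAP field is encoded by the MapWriter added above.
                ArrowWriter<RowData> writer =
                        ArrowUtils.createRowDataArrowWriter(root, rowType);
                writer.write(
                        GenericRowData.of(
                                new GenericMapData(
                                        Collections.singletonMap(
                                                StringData.fromString("a"), 1))));
                writer.finish();

                // Read the row back; the MAP column is decoded by
                // ArrowMapColumnVector.
                ArrowReader reader = ArrowUtils.createArrowReader(root, rowType);
                MapData map = reader.read(0).getMap(0);
                System.out.println(map.size()); // 1
            }
        }
    }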