You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by di...@apache.org on 2023/01/11 02:46:24 UTC

[flink] 01/03: [FLINK-30607][python] Support MapType for Pandas UDF

This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b781a13dd615e8d131defe37ca9e550416c10595
Author: Dian Fu <di...@apache.org>
AuthorDate: Tue Jan 10 15:29:55 2023 +0800

    [FLINK-30607][python] Support MapType for Pandas UDF
    
    This closes #21639.
---
 flink-python/pyflink/fn_execution/coders.py        |   6 +-
 .../pyflink/table/tests/test_pandas_udf.py         |  48 ++++++--
 flink-python/pyflink/table/types.py                |   6 +
 .../flink/table/runtime/arrow/ArrowUtils.java      |  54 +++++++++
 .../arrow/vectors/ArrowMapColumnVector.java        |  60 ++++++++++
 .../table/runtime/arrow/writers/MapWriter.java     | 130 +++++++++++++++++++++
 6 files changed, 294 insertions(+), 10 deletions(-)

diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py
index 1527adb4e39..704d278666b 100644
--- a/flink-python/pyflink/fn_execution/coders.py
+++ b/flink-python/pyflink/fn_execution/coders.py
@@ -29,7 +29,7 @@ from pyflink.common.typeinfo import TypeInformation, BasicTypeInfo, BasicType, D
     ExternalTypeInfo
 from pyflink.table.types import TinyIntType, SmallIntType, IntType, BigIntType, BooleanType, \
     FloatType, DoubleType, VarCharType, VarBinaryType, DecimalType, DateType, TimeType, \
-    LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType
+    LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType, MapType
 
 try:
     from pyflink.fn_execution import coder_impl_fast as coder_impl
@@ -148,6 +148,10 @@ class LengthPrefixBaseCoder(ABC):
             return RowType(
                 [RowField(f.name, cls._to_data_type(f.type), f.description)
                  for f in field_type.row_schema.fields], field_type.nullable)
+        elif field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.MAP:
+            return MapType(cls._to_data_type(field_type.map_info.key_type),
+                           cls._to_data_type(field_type.map_info.value_type),
+                           field_type.nullable)
         else:
             raise ValueError("field_type %s is not supported." % field_type)
 
diff --git a/flink-python/pyflink/table/tests/test_pandas_udf.py b/flink-python/pyflink/table/tests/test_pandas_udf.py
index aa0dd8a9596..46d9560b351 100644
--- a/flink-python/pyflink/table/tests/test_pandas_udf.py
+++ b/flink-python/pyflink/table/tests/test_pandas_udf.py
@@ -213,12 +213,38 @@ class PandasUDFITTests(object):
                 'row_param.f4 of wrong type %s !' % type(row_param.f4[0])
             return row_param
 
+        map_type = DataTypes.MAP(DataTypes.STRING(False), DataTypes.STRING())
+
+        @udf(result_type=map_type, func_type="pandas")
+        def map_func(map_param):
+            assert isinstance(map_param, pd.Series)
+            return map_param
+
         sink_table_ddl = """
-        CREATE TABLE Results_test_all_data_types(
-        a TINYINT, b SMALLINT, c INT, d BIGINT, e BOOLEAN, f BOOLEAN, g FLOAT, h DOUBLE, i STRING,
-        j StRING, k BYTES, l DECIMAL(38, 18), m DECIMAL(38, 18), n DATE, o TIME, p TIMESTAMP(3),
-        q ARRAY<STRING>, r ARRAY<TIMESTAMP(3)>, s ARRAY<INT>, t ARRAY<STRING>,
-        u ROW<f1 INT, f2 STRING, f3 TIMESTAMP(3), f4 ARRAY<INT>>) WITH ('connector'='test-sink')
+            CREATE TABLE Results_test_all_data_types(
+                a TINYINT,
+                b SMALLINT,
+                c INT,
+                d BIGINT,
+                e BOOLEAN,
+                f BOOLEAN,
+                g FLOAT,
+                h DOUBLE,
+                i STRING,
+                j StRING,
+                k BYTES,
+                l DECIMAL(38, 18),
+                m DECIMAL(38, 18),
+                n DATE,
+                o TIME,
+                p TIMESTAMP(3),
+                q ARRAY<STRING>,
+                r ARRAY<TIMESTAMP(3)>,
+                s ARRAY<INT>,
+                t ARRAY<STRING>,
+                u ROW<f1 INT, f2 STRING, f3 TIMESTAMP(3), f4 ARRAY<INT>>,
+                v MAP<STRING, STRING>
+            ) WITH ('connector'='test-sink')
         """
         self.t_env.execute_sql(sink_table_ddl)
 
@@ -228,7 +254,8 @@ class PandasUDFITTests(object):
               decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
               datetime.date(2014, 9, 13), datetime.time(hour=1, minute=0, second=1),
               timestamp_value, ['hello', '中文', None], [timestamp_value], [1, 2],
-              [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]))],
+              [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]),
+              {"1": "hello", "2": "world"})],
             DataTypes.ROW(
                 [DataTypes.FIELD("a", DataTypes.TINYINT()),
                  DataTypes.FIELD("b", DataTypes.SMALLINT()),
@@ -250,7 +277,8 @@ class PandasUDFITTests(object):
                  DataTypes.FIELD("r", DataTypes.ARRAY(DataTypes.TIMESTAMP(3))),
                  DataTypes.FIELD("s", DataTypes.ARRAY(DataTypes.INT())),
                  DataTypes.FIELD("t", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))),
-                 DataTypes.FIELD("u", row_type)]))
+                 DataTypes.FIELD("u", row_type),
+                 DataTypes.FIELD("v", map_type)]))
 
         t.select(
             tinyint_func(t.a),
@@ -273,7 +301,8 @@ class PandasUDFITTests(object):
             array_timestamp_func(t.r),
             array_int_func(t.s),
             nested_array_func(t.t),
-            row_func(t.u)) \
+            row_func(t.u),
+            map_func(t.v)) \
             .execute_insert("Results_test_all_data_types").wait()
         actual = source_sink_utils.results()
         self.assert_equals(
@@ -282,7 +311,8 @@ class PandasUDFITTests(object):
              "[102, 108, 105, 110, 107], 1000000000000000000.050000000000000000, "
              "1000000000000000000.059999999999999999, 2014-09-13, 01:00:01, "
              "1970-01-02T00:00:00.123, [hello, 中文, null], [1970-01-02T00:00:00.123], "
-             "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02T00:00:00.123, [1, 2]]]"])
+             "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02T00:00:00.123, [1, 2]], "
+             "{1=hello, 2=world}]"])
 
     def test_invalid_pandas_udf(self):
 
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index b3aac1d3525..83d54629dad 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -2244,6 +2244,10 @@ def from_arrow_type(arrow_type, nullable: bool = True) -> DataType:
             return TimestampType(6, nullable)
         else:
             return TimestampType(9, nullable)
+    elif types.is_map(arrow_type):
+        return MapType(from_arrow_type(arrow_type.key_type),
+                       from_arrow_type(arrow_type.item_type),
+                       nullable)
     elif types.is_list(arrow_type):
         return ArrayType(from_arrow_type(arrow_type.value_type), nullable)
     elif types.is_struct(arrow_type):
@@ -2301,6 +2305,8 @@ def to_arrow_type(data_type: DataType):
             return pa.timestamp('us')
         else:
             return pa.timestamp('ns')
+    elif isinstance(data_type, MapType):
+        return pa.map_(to_arrow_type(data_type.key_type), to_arrow_type(data_type.value_type))
     elif isinstance(data_type, ArrayType):
         if type(data_type.element_type) in [LocalZonedTimestampType, RowType]:
             raise ValueError("%s is not supported to be used as the element type of ArrayType." %
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
index 28f0a74e51d..2d4e09e4553 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
@@ -41,6 +41,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowDecimalColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
+import org.apache.flink.table.runtime.arrow.vectors.ArrowMapColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowRowColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
 import org.apache.flink.table.runtime.arrow.vectors.ArrowTimeColumnVector;
@@ -57,6 +58,7 @@ import org.apache.flink.table.runtime.arrow.writers.DecimalWriter;
 import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
 import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
 import org.apache.flink.table.runtime.arrow.writers.IntWriter;
+import org.apache.flink.table.runtime.arrow.writers.MapWriter;
 import org.apache.flink.table.runtime.arrow.writers.RowWriter;
 import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
 import org.apache.flink.table.runtime.arrow.writers.TimeWriter;
@@ -77,6 +79,7 @@ import org.apache.flink.table.types.logical.IntType;
 import org.apache.flink.table.types.logical.LegacyTypeInformationType;
 import org.apache.flink.table.types.logical.LocalZonedTimestampType;
 import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.SmallIntType;
 import org.apache.flink.table.types.logical.TimeType;
@@ -88,6 +91,7 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeDefaultVisitor;
 import org.apache.flink.table.types.utils.TypeConversions;
 import org.apache.flink.types.Row;
 import org.apache.flink.types.RowKind;
+import org.apache.flink.util.Preconditions;
 
 import org.apache.flink.shaded.guava30.com.google.common.collect.LinkedHashMultiset;
 
@@ -114,6 +118,7 @@ import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
 import org.apache.arrow.vector.complex.StructVector;
 import org.apache.arrow.vector.ipc.ArrowStreamWriter;
 import org.apache.arrow.vector.ipc.ReadChannel;
@@ -139,6 +144,7 @@ import java.nio.ByteBuffer;
 import java.nio.channels.Channels;
 import java.nio.channels.ReadableByteChannel;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
@@ -200,6 +206,18 @@ public final class ArrowUtils {
             for (RowType.RowField field : rowType.getFields()) {
                 children.add(toArrowField(field.getName(), field.getType()));
             }
+        } else if (logicalType instanceof MapType) {
+            MapType mapType = (MapType) logicalType;
+            Preconditions.checkArgument(
+                    !mapType.getKeyType().isNullable(), "Map key type should be non-nullable");
+            children =
+                    Collections.singletonList(
+                            new Field(
+                                    "items",
+                                    new FieldType(false, ArrowType.Struct.INSTANCE, null),
+                                    Arrays.asList(
+                                            toArrowField("key", mapType.getKeyType()),
+                                            toArrowField("value", mapType.getValueType()))));
         }
         return new Field(fieldName, fieldType, children);
     }
@@ -259,6 +277,17 @@ public final class ArrowUtils {
                 precision = ((TimestampType) fieldType).getPrecision();
             }
             return TimestampWriter.forRow(vector, precision);
+        } else if (vector instanceof MapVector) {
+            MapVector mapVector = (MapVector) vector;
+            LogicalType keyType = ((MapType) fieldType).getKeyType();
+            LogicalType valueType = ((MapType) fieldType).getValueType();
+            StructVector structVector = (StructVector) mapVector.getDataVector();
+            return MapWriter.forRow(
+                    mapVector,
+                    createArrowFieldWriterForArray(
+                            structVector.getChild(MapVector.KEY_NAME), keyType),
+                    createArrowFieldWriterForArray(
+                            structVector.getChild(MapVector.VALUE_NAME), valueType));
         } else if (vector instanceof ListVector) {
             ListVector listVector = (ListVector) vector;
             LogicalType elementType = ((ArrayType) fieldType).getElementType();
@@ -321,6 +350,17 @@ public final class ArrowUtils {
                 precision = ((TimestampType) fieldType).getPrecision();
             }
             return TimestampWriter.forArray(vector, precision);
+        } else if (vector instanceof MapVector) {
+            MapVector mapVector = (MapVector) vector;
+            LogicalType keyType = ((MapType) fieldType).getKeyType();
+            LogicalType valueType = ((MapType) fieldType).getValueType();
+            StructVector structVector = (StructVector) mapVector.getDataVector();
+            return MapWriter.forArray(
+                    mapVector,
+                    createArrowFieldWriterForArray(
+                            structVector.getChild(MapVector.KEY_NAME), keyType),
+                    createArrowFieldWriterForArray(
+                            structVector.getChild(MapVector.VALUE_NAME), valueType));
         } else if (vector instanceof ListVector) {
             ListVector listVector = (ListVector) vector;
             LogicalType elementType = ((ArrayType) fieldType).getElementType();
@@ -385,6 +425,15 @@ public final class ArrowUtils {
         } else if (vector instanceof TimeStampVector
                 && ((ArrowType.Timestamp) vector.getField().getType()).getTimezone() == null) {
             return new ArrowTimestampColumnVector(vector);
+        } else if (vector instanceof MapVector) {
+            MapVector mapVector = (MapVector) vector;
+            LogicalType keyType = ((MapType) fieldType).getKeyType();
+            LogicalType valueType = ((MapType) fieldType).getValueType();
+            StructVector structVector = (StructVector) mapVector.getDataVector();
+            return new ArrowMapColumnVector(
+                    mapVector,
+                    createColumnVector(structVector.getChild(MapVector.KEY_NAME), keyType),
+                    createColumnVector(structVector.getChild(MapVector.VALUE_NAME), valueType));
         } else if (vector instanceof ListVector) {
             ListVector listVector = (ListVector) vector;
             return new ArrowArrayColumnVector(
@@ -732,6 +781,11 @@ public final class ArrowUtils {
             return ArrowType.Struct.INSTANCE;
         }
 
+        @Override
+        public ArrowType visit(MapType mapType) {
+            return new ArrowType.Map(false);
+        }
+
         @Override
         protected ArrowType defaultMethod(LogicalType logicalType) {
             if (logicalType instanceof LegacyTypeInformationType) {
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java
new file mode 100644
index 00000000000..d269c7e28a3
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.vectors;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.columnar.ColumnarMapData;
+import org.apache.flink.table.data.columnar.vector.ColumnVector;
+import org.apache.flink.table.data.columnar.vector.MapColumnVector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.vector.complex.MapVector;
+
+/** Arrow column vector for Map. */
+@Internal
+public final class ArrowMapColumnVector implements MapColumnVector {
+
+    /** Container which is used to store the map values of a column to read. */
+    private final MapVector mapVector;
+
+    private final ColumnVector keyVector;
+
+    private final ColumnVector valueVector;
+
+    public ArrowMapColumnVector(
+            MapVector mapVector, ColumnVector keyVector, ColumnVector valueVector) {
+        this.mapVector = Preconditions.checkNotNull(mapVector);
+        this.keyVector = Preconditions.checkNotNull(keyVector);
+        this.valueVector = Preconditions.checkNotNull(valueVector);
+    }
+
+    @Override
+    public MapData getMap(int i) {
+        int index = i * MapVector.OFFSET_WIDTH;
+        int offset = mapVector.getOffsetBuffer().getInt(index);
+        int numElements = mapVector.getInnerValueCountAt(i);
+        return new ColumnarMapData(keyVector, valueVector, offset, numElements);
+    }
+
+    @Override
+    public boolean isNullAt(int i) {
+        return mapVector.isNull(i);
+    }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java
new file mode 100644
index 00000000000..c2c960c3328
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.writers;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+
+/** {@link ArrowFieldWriter} for Map. */
+@Internal
+public abstract class MapWriter<T> extends ArrowFieldWriter<T> {
+
+    public static MapWriter<RowData> forRow(
+            MapVector mapVector,
+            ArrowFieldWriter<ArrayData> keyWriter,
+            ArrowFieldWriter<ArrayData> valueWriter) {
+        return new MapWriterForRow(mapVector, keyWriter, valueWriter);
+    }
+
+    public static MapWriter<ArrayData> forArray(
+            MapVector mapVector,
+            ArrowFieldWriter<ArrayData> keyWriter,
+            ArrowFieldWriter<ArrayData> valueWriter) {
+        return new MapWriterForArray(mapVector, keyWriter, valueWriter);
+    }
+
+    // ------------------------------------------------------------------------------------------
+
+    private final ArrowFieldWriter<ArrayData> keyWriter;
+
+    private final ArrowFieldWriter<ArrayData> valueWriter;
+
+    private MapWriter(
+            MapVector mapVector,
+            ArrowFieldWriter<ArrayData> keyWriter,
+            ArrowFieldWriter<ArrayData> valueWriter) {
+        super(mapVector);
+        this.keyWriter = Preconditions.checkNotNull(keyWriter);
+        this.valueWriter = Preconditions.checkNotNull(valueWriter);
+    }
+
+    abstract boolean isNullAt(T in, int ordinal);
+
+    abstract MapData readMap(T in, int ordinal);
+
+    @Override
+    public void doWrite(T in, int ordinal) {
+        if (!isNullAt(in, ordinal)) {
+            ((MapVector) getValueVector()).startNewValue(getCount());
+
+            StructVector structVector =
+                    (StructVector) ((MapVector) getValueVector()).getDataVector();
+            MapData map = readMap(in, ordinal);
+            ArrayData keys = map.keyArray();
+            ArrayData values = map.valueArray();
+            for (int i = 0; i < map.size(); i++) {
+                structVector.setIndexDefined(keyWriter.getCount());
+                keyWriter.write(keys, i);
+                valueWriter.write(values, i);
+            }
+
+            ((MapVector) getValueVector()).endValue(getCount(), map.size());
+        }
+    }
+
+    // ------------------------------------------------------------------------------------------
+
+    /** {@link MapWriter} for {@link RowData} input. */
+    public static final class MapWriterForRow extends MapWriter<RowData> {
+
+        private MapWriterForRow(
+                MapVector mapVector,
+                ArrowFieldWriter<ArrayData> keyWriter,
+                ArrowFieldWriter<ArrayData> valueWriter) {
+            super(mapVector, keyWriter, valueWriter);
+        }
+
+        @Override
+        boolean isNullAt(RowData in, int ordinal) {
+            return in.isNullAt(ordinal);
+        }
+
+        @Override
+        MapData readMap(RowData in, int ordinal) {
+            return in.getMap(ordinal);
+        }
+    }
+
+    /** {@link MapWriter} for {@link ArrayData} input. */
+    public static final class MapWriterForArray extends MapWriter<ArrayData> {
+
+        private MapWriterForArray(
+                MapVector mapVector,
+                ArrowFieldWriter<ArrayData> keyWriter,
+                ArrowFieldWriter<ArrayData> valueWriter) {
+            super(mapVector, keyWriter, valueWriter);
+        }
+
+        @Override
+        boolean isNullAt(ArrayData in, int ordinal) {
+            return in.isNullAt(ordinal);
+        }
+
+        @Override
+        MapData readMap(ArrayData in, int ordinal) {
+            return in.getMap(ordinal);
+        }
+    }
+}