You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by di...@apache.org on 2023/01/11 02:46:24 UTC
[flink] 01/03: [FLINK-30607][python] Support MapType for Pandas UDF
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit b781a13dd615e8d131defe37ca9e550416c10595
Author: Dian Fu <di...@apache.org>
AuthorDate: Tue Jan 10 15:29:55 2023 +0800
[FLINK-30607][python] Support MapType for Pandas UDF
This closes #21639.
---
flink-python/pyflink/fn_execution/coders.py | 6 +-
.../pyflink/table/tests/test_pandas_udf.py | 48 ++++++--
flink-python/pyflink/table/types.py | 6 +
.../flink/table/runtime/arrow/ArrowUtils.java | 54 +++++++++
.../arrow/vectors/ArrowMapColumnVector.java | 60 ++++++++++
.../table/runtime/arrow/writers/MapWriter.java | 130 +++++++++++++++++++++
6 files changed, 294 insertions(+), 10 deletions(-)
diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py
index 1527adb4e39..704d278666b 100644
--- a/flink-python/pyflink/fn_execution/coders.py
+++ b/flink-python/pyflink/fn_execution/coders.py
@@ -29,7 +29,7 @@ from pyflink.common.typeinfo import TypeInformation, BasicTypeInfo, BasicType, D
ExternalTypeInfo
from pyflink.table.types import TinyIntType, SmallIntType, IntType, BigIntType, BooleanType, \
FloatType, DoubleType, VarCharType, VarBinaryType, DecimalType, DateType, TimeType, \
- LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType
+ LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType, MapType
try:
from pyflink.fn_execution import coder_impl_fast as coder_impl
@@ -148,6 +148,10 @@ class LengthPrefixBaseCoder(ABC):
return RowType(
[RowField(f.name, cls._to_data_type(f.type), f.description)
for f in field_type.row_schema.fields], field_type.nullable)
+ elif field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.MAP:
+ return MapType(cls._to_data_type(field_type.map_info.key_type),
+ cls._to_data_type(field_type.map_info.value_type),
+ field_type.nullable)
else:
raise ValueError("field_type %s is not supported." % field_type)
diff --git a/flink-python/pyflink/table/tests/test_pandas_udf.py b/flink-python/pyflink/table/tests/test_pandas_udf.py
index aa0dd8a9596..46d9560b351 100644
--- a/flink-python/pyflink/table/tests/test_pandas_udf.py
+++ b/flink-python/pyflink/table/tests/test_pandas_udf.py
@@ -213,12 +213,38 @@ class PandasUDFITTests(object):
'row_param.f4 of wrong type %s !' % type(row_param.f4[0])
return row_param
+ map_type = DataTypes.MAP(DataTypes.STRING(False), DataTypes.STRING())
+
+ @udf(result_type=map_type, func_type="pandas")
+ def map_func(map_param):
+ assert isinstance(map_param, pd.Series)
+ return map_param
+
sink_table_ddl = """
- CREATE TABLE Results_test_all_data_types(
- a TINYINT, b SMALLINT, c INT, d BIGINT, e BOOLEAN, f BOOLEAN, g FLOAT, h DOUBLE, i STRING,
- j StRING, k BYTES, l DECIMAL(38, 18), m DECIMAL(38, 18), n DATE, o TIME, p TIMESTAMP(3),
- q ARRAY<STRING>, r ARRAY<TIMESTAMP(3)>, s ARRAY<INT>, t ARRAY<STRING>,
- u ROW<f1 INT, f2 STRING, f3 TIMESTAMP(3), f4 ARRAY<INT>>) WITH ('connector'='test-sink')
+ CREATE TABLE Results_test_all_data_types(
+ a TINYINT,
+ b SMALLINT,
+ c INT,
+ d BIGINT,
+ e BOOLEAN,
+ f BOOLEAN,
+ g FLOAT,
+ h DOUBLE,
+ i STRING,
+ j StRING,
+ k BYTES,
+ l DECIMAL(38, 18),
+ m DECIMAL(38, 18),
+ n DATE,
+ o TIME,
+ p TIMESTAMP(3),
+ q ARRAY<STRING>,
+ r ARRAY<TIMESTAMP(3)>,
+ s ARRAY<INT>,
+ t ARRAY<STRING>,
+ u ROW<f1 INT, f2 STRING, f3 TIMESTAMP(3), f4 ARRAY<INT>>,
+ v MAP<STRING, STRING>
+ ) WITH ('connector'='test-sink')
"""
self.t_env.execute_sql(sink_table_ddl)
@@ -228,7 +254,8 @@ class PandasUDFITTests(object):
decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
datetime.date(2014, 9, 13), datetime.time(hour=1, minute=0, second=1),
timestamp_value, ['hello', '中文', None], [timestamp_value], [1, 2],
- [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]))],
+ [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]),
+ {"1": "hello", "2": "world"})],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b", DataTypes.SMALLINT()),
@@ -250,7 +277,8 @@ class PandasUDFITTests(object):
DataTypes.FIELD("r", DataTypes.ARRAY(DataTypes.TIMESTAMP(3))),
DataTypes.FIELD("s", DataTypes.ARRAY(DataTypes.INT())),
DataTypes.FIELD("t", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))),
- DataTypes.FIELD("u", row_type)]))
+ DataTypes.FIELD("u", row_type),
+ DataTypes.FIELD("v", map_type)]))
t.select(
tinyint_func(t.a),
@@ -273,7 +301,8 @@ class PandasUDFITTests(object):
array_timestamp_func(t.r),
array_int_func(t.s),
nested_array_func(t.t),
- row_func(t.u)) \
+ row_func(t.u),
+ map_func(t.v)) \
.execute_insert("Results_test_all_data_types").wait()
actual = source_sink_utils.results()
self.assert_equals(
@@ -282,7 +311,8 @@ class PandasUDFITTests(object):
"[102, 108, 105, 110, 107], 1000000000000000000.050000000000000000, "
"1000000000000000000.059999999999999999, 2014-09-13, 01:00:01, "
"1970-01-02T00:00:00.123, [hello, 中文, null], [1970-01-02T00:00:00.123], "
- "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02T00:00:00.123, [1, 2]]]"])
+ "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02T00:00:00.123, [1, 2]], "
+ "{1=hello, 2=world}]"])
def test_invalid_pandas_udf(self):
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index b3aac1d3525..83d54629dad 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -2244,6 +2244,10 @@ def from_arrow_type(arrow_type, nullable: bool = True) -> DataType:
return TimestampType(6, nullable)
else:
return TimestampType(9, nullable)
+ elif types.is_map(arrow_type):
+ return MapType(from_arrow_type(arrow_type.key_type),
+ from_arrow_type(arrow_type.item_type),
+ nullable)
elif types.is_list(arrow_type):
return ArrayType(from_arrow_type(arrow_type.value_type), nullable)
elif types.is_struct(arrow_type):
@@ -2301,6 +2305,8 @@ def to_arrow_type(data_type: DataType):
return pa.timestamp('us')
else:
return pa.timestamp('ns')
+ elif isinstance(data_type, MapType):
+ return pa.map_(to_arrow_type(data_type.key_type), to_arrow_type(data_type.value_type))
elif isinstance(data_type, ArrayType):
if type(data_type.element_type) in [LocalZonedTimestampType, RowType]:
raise ValueError("%s is not supported to be used as the element type of ArrayType." %
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
index 28f0a74e51d..2d4e09e4553 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/ArrowUtils.java
@@ -41,6 +41,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowDecimalColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
+import org.apache.flink.table.runtime.arrow.vectors.ArrowMapColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowRowColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTimeColumnVector;
@@ -57,6 +58,7 @@ import org.apache.flink.table.runtime.arrow.writers.DecimalWriter;
import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
import org.apache.flink.table.runtime.arrow.writers.IntWriter;
+import org.apache.flink.table.runtime.arrow.writers.MapWriter;
import org.apache.flink.table.runtime.arrow.writers.RowWriter;
import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.TimeWriter;
@@ -77,6 +79,7 @@ import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LegacyTypeInformationType;
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TimeType;
@@ -88,6 +91,7 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeDefaultVisitor;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
+import org.apache.flink.util.Preconditions;
import org.apache.flink.shaded.guava30.com.google.common.collect.LinkedHashMultiset;
@@ -114,6 +118,7 @@ import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
@@ -139,6 +144,7 @@ import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
@@ -200,6 +206,18 @@ public final class ArrowUtils {
for (RowType.RowField field : rowType.getFields()) {
children.add(toArrowField(field.getName(), field.getType()));
}
+ } else if (logicalType instanceof MapType) {
+ MapType mapType = (MapType) logicalType;
+ Preconditions.checkArgument(
+ !mapType.getKeyType().isNullable(), "Map key type should be non-nullable");
+ children =
+ Collections.singletonList(
+ new Field(
+ "items",
+ new FieldType(false, ArrowType.Struct.INSTANCE, null),
+ Arrays.asList(
+ toArrowField("key", mapType.getKeyType()),
+ toArrowField("value", mapType.getValueType()))));
}
return new Field(fieldName, fieldType, children);
}
@@ -259,6 +277,17 @@ public final class ArrowUtils {
precision = ((TimestampType) fieldType).getPrecision();
}
return TimestampWriter.forRow(vector, precision);
+ } else if (vector instanceof MapVector) {
+ MapVector mapVector = (MapVector) vector;
+ LogicalType keyType = ((MapType) fieldType).getKeyType();
+ LogicalType valueType = ((MapType) fieldType).getValueType();
+ StructVector structVector = (StructVector) mapVector.getDataVector();
+ return MapWriter.forRow(
+ mapVector,
+ createArrowFieldWriterForArray(
+ structVector.getChild(MapVector.KEY_NAME), keyType),
+ createArrowFieldWriterForArray(
+ structVector.getChild(MapVector.VALUE_NAME), valueType));
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
LogicalType elementType = ((ArrayType) fieldType).getElementType();
@@ -321,6 +350,17 @@ public final class ArrowUtils {
precision = ((TimestampType) fieldType).getPrecision();
}
return TimestampWriter.forArray(vector, precision);
+ } else if (vector instanceof MapVector) {
+ MapVector mapVector = (MapVector) vector;
+ LogicalType keyType = ((MapType) fieldType).getKeyType();
+ LogicalType valueType = ((MapType) fieldType).getValueType();
+ StructVector structVector = (StructVector) mapVector.getDataVector();
+ return MapWriter.forArray(
+ mapVector,
+ createArrowFieldWriterForArray(
+ structVector.getChild(MapVector.KEY_NAME), keyType),
+ createArrowFieldWriterForArray(
+ structVector.getChild(MapVector.VALUE_NAME), valueType));
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
LogicalType elementType = ((ArrayType) fieldType).getElementType();
@@ -385,6 +425,15 @@ public final class ArrowUtils {
} else if (vector instanceof TimeStampVector
&& ((ArrowType.Timestamp) vector.getField().getType()).getTimezone() == null) {
return new ArrowTimestampColumnVector(vector);
+ } else if (vector instanceof MapVector) {
+ MapVector mapVector = (MapVector) vector;
+ LogicalType keyType = ((MapType) fieldType).getKeyType();
+ LogicalType valueType = ((MapType) fieldType).getValueType();
+ StructVector structVector = (StructVector) mapVector.getDataVector();
+ return new ArrowMapColumnVector(
+ mapVector,
+ createColumnVector(structVector.getChild(MapVector.KEY_NAME), keyType),
+ createColumnVector(structVector.getChild(MapVector.VALUE_NAME), valueType));
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
return new ArrowArrayColumnVector(
@@ -732,6 +781,11 @@ public final class ArrowUtils {
return ArrowType.Struct.INSTANCE;
}
+ @Override
+ public ArrowType visit(MapType mapType) {
+ return new ArrowType.Map(false);
+ }
+
@Override
protected ArrowType defaultMethod(LogicalType logicalType) {
if (logicalType instanceof LegacyTypeInformationType) {
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java
new file mode 100644
index 00000000000..d269c7e28a3
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/vectors/ArrowMapColumnVector.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.vectors;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.columnar.ColumnarMapData;
+import org.apache.flink.table.data.columnar.vector.ColumnVector;
+import org.apache.flink.table.data.columnar.vector.MapColumnVector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.vector.complex.MapVector;
+
+/** Arrow column vector for Map. */
+@Internal
+public final class ArrowMapColumnVector implements MapColumnVector {
+
+ /** Container which is used to store the map values of a column to read. */
+ private final MapVector mapVector;
+
+ private final ColumnVector keyVector;
+
+ private final ColumnVector valueVector;
+
+ public ArrowMapColumnVector(
+ MapVector mapVector, ColumnVector keyVector, ColumnVector valueVector) {
+ this.mapVector = Preconditions.checkNotNull(mapVector);
+ this.keyVector = Preconditions.checkNotNull(keyVector);
+ this.valueVector = Preconditions.checkNotNull(valueVector);
+ }
+
+ @Override
+ public MapData getMap(int i) {
+ int index = i * MapVector.OFFSET_WIDTH;
+ int offset = mapVector.getOffsetBuffer().getInt(index);
+ int numElements = mapVector.getInnerValueCountAt(i);
+ return new ColumnarMapData(keyVector, valueVector, offset, numElements);
+ }
+
+ @Override
+ public boolean isNullAt(int i) {
+ return mapVector.isNull(i);
+ }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java
new file mode 100644
index 00000000000..c2c960c3328
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.arrow.writers;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+
+/** {@link ArrowFieldWriter} for Map. */
+@Internal
+public abstract class MapWriter<T> extends ArrowFieldWriter<T> {
+
+ public static MapWriter<RowData> forRow(
+ MapVector mapVector,
+ ArrowFieldWriter<ArrayData> keyWriter,
+ ArrowFieldWriter<ArrayData> valueWriter) {
+ return new MapWriterForRow(mapVector, keyWriter, valueWriter);
+ }
+
+ public static MapWriter<ArrayData> forArray(
+ MapVector mapVector,
+ ArrowFieldWriter<ArrayData> keyWriter,
+ ArrowFieldWriter<ArrayData> valueWriter) {
+ return new MapWriterForArray(mapVector, keyWriter, valueWriter);
+ }
+
+ // ------------------------------------------------------------------------------------------
+
+ private final ArrowFieldWriter<ArrayData> keyWriter;
+
+ private final ArrowFieldWriter<ArrayData> valueWriter;
+
+ private MapWriter(
+ MapVector mapVector,
+ ArrowFieldWriter<ArrayData> keyWriter,
+ ArrowFieldWriter<ArrayData> valueWriter) {
+ super(mapVector);
+ this.keyWriter = Preconditions.checkNotNull(keyWriter);
+ this.valueWriter = Preconditions.checkNotNull(valueWriter);
+ }
+
+ abstract boolean isNullAt(T in, int ordinal);
+
+ abstract MapData readMap(T in, int ordinal);
+
+ @Override
+ public void doWrite(T in, int ordinal) {
+ if (!isNullAt(in, ordinal)) {
+ ((MapVector) getValueVector()).startNewValue(getCount());
+
+ StructVector structVector =
+ (StructVector) ((MapVector) getValueVector()).getDataVector();
+ MapData map = readMap(in, ordinal);
+ ArrayData keys = map.keyArray();
+ ArrayData values = map.valueArray();
+ for (int i = 0; i < map.size(); i++) {
+ structVector.setIndexDefined(keyWriter.getCount());
+ keyWriter.write(keys, i);
+ valueWriter.write(values, i);
+ }
+
+ ((MapVector) getValueVector()).endValue(getCount(), map.size());
+ }
+ }
+
+ // ------------------------------------------------------------------------------------------
+
+ /** {@link MapWriter} for {@link RowData} input. */
+ public static final class MapWriterForRow extends MapWriter<RowData> {
+
+ private MapWriterForRow(
+ MapVector mapVector,
+ ArrowFieldWriter<ArrayData> keyWriter,
+ ArrowFieldWriter<ArrayData> valueWriter) {
+ super(mapVector, keyWriter, valueWriter);
+ }
+
+ @Override
+ boolean isNullAt(RowData in, int ordinal) {
+ return in.isNullAt(ordinal);
+ }
+
+ @Override
+ MapData readMap(RowData in, int ordinal) {
+ return in.getMap(ordinal);
+ }
+ }
+
+ /** {@link MapWriter} for {@link ArrayData} input. */
+ public static final class MapWriterForArray extends MapWriter<ArrayData> {
+
+ private MapWriterForArray(
+ MapVector mapVector,
+ ArrowFieldWriter<ArrayData> keyWriter,
+ ArrowFieldWriter<ArrayData> valueWriter) {
+ super(mapVector, keyWriter, valueWriter);
+ }
+
+ @Override
+ boolean isNullAt(ArrayData in, int ordinal) {
+ return in.isNullAt(ordinal);
+ }
+
+ @Override
+ MapData readMap(ArrayData in, int ordinal) {
+ return in.getMap(ordinal);
+ }
+ }
+}