You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2022/01/18 13:33:52 UTC

[flink] 02/02: [FLINK-17321][table] Add support casting of map to map and multiset to multiset

This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit f5c99c6f2612bc2ae437e85f5c44cae50f631e4e
Author: Sergey Nuyanzin <sn...@gmail.com>
AuthorDate: Wed Dec 15 18:21:28 2021 +0100

    [FLINK-17321][table] Add support casting of map to map and multiset to multiset
    
    This closes #18287.
---
 .../functions/casting/CastRuleProvider.java        |   1 +
 .../MapToMapAndMultisetToMultisetCastRule.java     | 198 +++++++++++++++++++++
 .../planner/functions/CastFunctionITCase.java      |  45 ++++-
 .../planner/functions/casting/CastRulesTest.java   |  59 ++++++
 4 files changed, 297 insertions(+), 6 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java
index 5083519..961e81f 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java
@@ -81,6 +81,7 @@ public class CastRuleProvider {
                 .addRule(RawToBinaryCastRule.INSTANCE)
                 // Collection rules
                 .addRule(ArrayToArrayCastRule.INSTANCE)
+                .addRule(MapToMapAndMultisetToMultisetCastRule.INSTANCE)
                 .addRule(RowToRowCastRule.INSTANCE)
                 // Special rules
                 .addRule(CharVarCharTrimPadCastRule.INSTANCE)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapToMapAndMultisetToMultisetCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapToMapAndMultisetToMultisetCastRule.java
new file mode 100644
index 0000000..89e0351
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapToMapAndMultisetToMultisetCastRule.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.casting;
+
+import org.apache.flink.table.data.GenericMapData;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.MapType;
+import org.apache.flink.table.types.logical.MultisetType;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.flink.table.planner.codegen.CodeGenUtils.boxedTypeTermForType;
+import static org.apache.flink.table.planner.codegen.CodeGenUtils.className;
+import static org.apache.flink.table.planner.codegen.CodeGenUtils.newName;
+import static org.apache.flink.table.planner.codegen.CodeGenUtils.rowFieldReadAccess;
+import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.constructorCall;
+import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.methodCall;
+
+/**
+ * Cast rule for {@link LogicalTypeRoot#MAP} to MAP and {@link LogicalTypeRoot#MULTISET} to
+ * MULTISET: casts every key and value (multiset counts stay INT) into a {@link GenericMapData}.
+ */
+class MapToMapAndMultisetToMultisetCastRule
+        extends AbstractNullAwareCodeGeneratorCastRule<MapData, MapData> {
+
+    static final MapToMapAndMultisetToMultisetCastRule INSTANCE =
+            new MapToMapAndMultisetToMultisetCastRule();
+
+    private MapToMapAndMultisetToMultisetCastRule() {
+        super(
+                CastRulePredicate.builder()
+                        .predicate(
+                                MapToMapAndMultisetToMultisetCastRule
+                                        ::isValidMapToMapOrMultisetToMultisetCasting)
+                        .build());
+    }
+
+    private static boolean isValidMapToMapOrMultisetToMultisetCasting(
+            LogicalType input, LogicalType target) {
+        return input.is(LogicalTypeRoot.MAP)
+                        && target.is(LogicalTypeRoot.MAP)
+                        && CastRuleProvider.resolve(
+                                        ((MapType) input).getKeyType(),
+                                        ((MapType) target).getKeyType())
+                                != null
+                        && CastRuleProvider.resolve(
+                                        ((MapType) input).getValueType(),
+                                        ((MapType) target).getValueType())
+                                != null
+                || input.is(LogicalTypeRoot.MULTISET)
+                        && target.is(LogicalTypeRoot.MULTISET)
+                        && CastRuleProvider.resolve(
+                                        ((MultisetType) input).getElementType(),
+                                        ((MultisetType) target).getElementType())
+                                != null;
+    }
+
+    /* Example generated code for MULTISET<INT> -> MULTISET<FLOAT>:
+    org.apache.flink.table.data.MapData _myInput = ((org.apache.flink.table.data.MapData)(_myInputObj));
+    boolean _myInputIsNull = _myInputObj == null;
+    boolean isNull$0;
+    org.apache.flink.table.data.MapData result$1;
+    float result$2;
+    isNull$0 = _myInputIsNull;
+    if (!isNull$0) {
+        java.util.Map map$838 = new java.util.HashMap();
+        for (int i$841 = 0; i$841 < _myInput.size(); i$841++) {
+            java.lang.Float key$839 = null;
+            java.lang.Integer value$840 = null;
+            if (!_myInput.keyArray().isNullAt(i$841)) {
+                result$2 = ((float)(_myInput.keyArray().getInt(i$841)));
+                key$839 = result$2;
+            }
+            value$840 = _myInput.valueArray().getInt(i$841);
+            map$838.put(key$839, value$840);
+        }
+        result$1 = new org.apache.flink.table.data.GenericMapData(map$838);
+        isNull$0 = result$1 == null;
+    } else {
+        result$1 = null;
+    }
+    return result$1;
+    Note: multiset counts are copied without an isNullAt guard; only MAP values may get one.
+     */
+    @Override
+    protected String generateCodeBlockInternal(
+            CodeGeneratorCastRule.Context context,
+            String inputTerm,
+            String returnVariable,
+            LogicalType inputLogicalType,
+            LogicalType targetLogicalType) {
+        final LogicalType innerInputKeyType;
+        final LogicalType innerInputValueType;
+
+        final LogicalType innerTargetKeyType;
+        final LogicalType innerTargetValueType;
+        if (inputLogicalType.is(LogicalTypeRoot.MULTISET)) {
+            innerInputKeyType = ((MultisetType) inputLogicalType).getElementType();
+            innerInputValueType = new IntType(false); // multiset count: INT NOT NULL
+            innerTargetKeyType = ((MultisetType) targetLogicalType).getElementType();
+            innerTargetValueType = new IntType(false);
+        } else {
+            innerInputKeyType = ((MapType) inputLogicalType).getKeyType();
+            innerInputValueType = ((MapType) inputLogicalType).getValueType();
+            innerTargetKeyType = ((MapType) targetLogicalType).getKeyType();
+            innerTargetValueType = ((MapType) targetLogicalType).getValueType();
+        }
+
+        final String innerTargetKeyTypeTerm = boxedTypeTermForType(innerTargetKeyType);
+        final String innerTargetValueTypeTerm = boxedTypeTermForType(innerTargetValueType);
+        final String keyArrayTerm = methodCall(inputTerm, "keyArray");
+        final String valueArrayTerm = methodCall(inputTerm, "valueArray");
+        final String size = methodCall(inputTerm, "size");
+        final String map = newName("map");
+        final String key = newName("key");
+        final String value = newName("value");
+
+        return new CastRuleUtils.CodeWriter()
+                .declStmt(className(Map.class), map, constructorCall(HashMap.class))
+                .forStmt(
+                        size,
+                        (index, codeWriter) -> {
+                            final CastCodeBlock keyCodeBlock =
+                                    CastRuleProvider.generateAlwaysNonNullCodeBlock(
+                                            context,
+                                            rowFieldReadAccess(
+                                                    index, keyArrayTerm, innerInputKeyType),
+                                            innerInputKeyType,
+                                            innerTargetKeyType);
+                            assert keyCodeBlock != null;
+
+                            final CastCodeBlock valueCodeBlock =
+                                    CastRuleProvider.generateAlwaysNonNullCodeBlock(
+                                            context,
+                                            rowFieldReadAccess(
+                                                    index, valueArrayTerm, innerInputValueType),
+                                            innerInputValueType,
+                                            innerTargetValueType);
+                            assert valueCodeBlock != null;
+
+                            codeWriter
+                                    .declStmt(innerTargetKeyTypeTerm, key, null)
+                                    .declStmt(innerTargetValueTypeTerm, value, null);
+                            if (innerTargetKeyType.isNullable()) { // null keys stay null
+                                codeWriter.ifStmt(
+                                        "!" + methodCall(keyArrayTerm, "isNullAt", index),
+                                        thenWriter ->
+                                                thenWriter
+                                                        .append(keyCodeBlock)
+                                                        .assignStmt(
+                                                                key, keyCodeBlock.getReturnTerm()));
+                            } else { // NOT NULL target key: cast emitted without a null guard
+                                codeWriter
+                                        .append(keyCodeBlock)
+                                        .assignStmt(key, keyCodeBlock.getReturnTerm());
+                            }
+
+                            if (inputLogicalType.is(LogicalTypeRoot.MAP)
+                                    && innerTargetValueType.isNullable()) {
+                                codeWriter.ifStmt(
+                                        "!" + methodCall(valueArrayTerm, "isNullAt", index),
+                                        thenWriter ->
+                                                thenWriter
+                                                        .append(valueCodeBlock)
+                                                        .assignStmt(
+                                                                value,
+                                                                valueCodeBlock.getReturnTerm()));
+                            } else { // multiset count or NOT NULL target value: no null guard
+                                codeWriter
+                                        .append(valueCodeBlock)
+                                        .assignStmt(value, valueCodeBlock.getReturnTerm());
+                            }
+                            codeWriter.stmt(methodCall(map, "put", key, value));
+                        })
+                .assignStmt(returnVariable, constructorCall(GenericMapData.class, map))
+                .toString();
+    }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java
index a0449a8..ae9e595 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java
@@ -40,10 +40,13 @@ import java.time.LocalTime;
 import java.time.Period;
 import java.time.ZoneId;
 import java.time.ZoneOffset;
+import java.util.AbstractMap;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import static org.apache.flink.table.api.DataTypes.ARRAY;
 import static org.apache.flink.table.api.DataTypes.BIGINT;
@@ -58,6 +61,7 @@ import static org.apache.flink.table.api.DataTypes.DOUBLE;
 import static org.apache.flink.table.api.DataTypes.FLOAT;
 import static org.apache.flink.table.api.DataTypes.INT;
 import static org.apache.flink.table.api.DataTypes.INTERVAL;
+import static org.apache.flink.table.api.DataTypes.MAP;
 import static org.apache.flink.table.api.DataTypes.MONTH;
 import static org.apache.flink.table.api.DataTypes.ROW;
 import static org.apache.flink.table.api.DataTypes.SECOND;
@@ -1142,14 +1146,27 @@ public class CastFunctionITCase extends BuiltInFunctionTestBase {
 
     public static List<TestSpec> constructedTypes() {
         return Arrays.asList(
-                // https://issues.apache.org/jira/browse/FLINK-17321
-                // MULTISET
-                // MAP
+                CastTestSpecBuilder.testCastTo(MAP(STRING(), STRING()))
+                        .fromCase(MAP(FLOAT(), DOUBLE()), null, null)
+                        .fromCase(
+                                MAP(INT(), INT()),
+                                Collections.singletonMap(1, 2),
+                                Collections.singletonMap("1", "2"))
+                        .build(),
+                // https://issues.apache.org/jira/browse/FLINK-25567
+                // CastTestSpecBuilder.testCastTo(MULTISET(STRING()))
+                //        .fromCase(MULTISET(TIMESTAMP()), null, null)
+                //        .fromCase(
+                //                MULTISET(INT()),
+                //                map(entry(1, 2), entry(3, 4)),
+                //                map(entry("1", 2), entry("3", 4)))
+                //        .build(),
                 CastTestSpecBuilder.testCastTo(ARRAY(INT()))
                         .fromCase(ARRAY(INT()), null, null)
-                        // https://issues.apache.org/jira/browse/FLINK-17321
-                        // .fromCase(ARRAY(STRING()), new String[] {'1', '2', '3'}, new Integer[]
-                        // {1, 2, 3})
+                        .fromCase(
+                                ARRAY(STRING()),
+                                new String[] {"1", "2", "3"},
+                                new Integer[] {1, 2, 3})
                         // https://issues.apache.org/jira/browse/FLINK-24425 Cast from corresponding
                         // single type
                         // .fromCase(INT(), DEFAULT_POSITIVE_INT, new int[] {DEFAULT_POSITIVE_INT})
@@ -1314,4 +1331,20 @@ public class CastFunctionITCase extends BuiltInFunctionTestBase {
     private static boolean isTimestampToNumeric(LogicalType srcType, LogicalType trgType) {
         return srcType.is(LogicalTypeFamily.TIMESTAMP) && trgType.is(LogicalTypeFamily.NUMERIC);
     }
+
+    private static <K, V> Map.Entry<K, V> entry(K k, V v) { // immutable key/value pair
+        return new AbstractMap.SimpleImmutableEntry<>(k, v);
+    }
+
+    @SafeVarargs
+    private static <K, V> Map<K, V> map(Map.Entry<K, V>... entries) {
+        if (entries == null) {
+            return Collections.emptyMap();
+        }
+        Map<K, V> map = new HashMap<>();
+        for (Map.Entry<K, V> entry : entries) {
+            map.put(entry.getKey(), entry.getValue());
+        }
+        return map;
+    }
 }
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java
index 9bb13a9..a0327c5 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java
@@ -1131,6 +1131,56 @@ class CastRulesTest {
                                             new GenericArrayData(new Integer[] {3})
                                         }),
                                 NullPointerException.class),
+                CastTestSpecBuilder.testCastTo(MAP(DOUBLE().notNull(), DOUBLE().notNull()))
+                        .fromCase(
+                                MAP(INT().nullable(), INT().nullable()),
+                                mapData(entry(1, 2)),
+                                mapData(entry(1d, 2d))),
+                CastTestSpecBuilder.testCastTo(MAP(BIGINT().nullable(), BIGINT().nullable()))
+                        .fromCase(
+                                MAP(INT().nullable(), INT().nullable()),
+                                mapData(entry(1, 2)),
+                                mapData(entry(1L, 2L))),
+                CastTestSpecBuilder.testCastTo(MAP(BIGINT().nullable(), BIGINT().nullable()))
+                        .fromCase(
+                                MAP(INT().nullable(), INT().nullable()),
+                                mapData(entry(1, 2), entry(null, 3), entry(4, null)),
+                                mapData(entry(1L, 2L), entry(null, 3L), entry(4L, null))),
+                CastTestSpecBuilder.testCastTo(MAP(STRING().nullable(), STRING().nullable()))
+                        .fromCase(
+                                MAP(TIMESTAMP().nullable(), DOUBLE().nullable()),
+                                mapData(entry(TIMESTAMP, 123.456)),
+                                mapData(entry(TIMESTAMP_STRING, fromString("123.456")))),
+                CastTestSpecBuilder.testCastTo(MAP(STRING().notNull(), STRING().nullable()))
+                        .fail(
+                                MAP(INT().nullable(), DOUBLE().nullable()),
+                                mapData(entry(null, 1d)),
+                                NullPointerException.class),
+                CastTestSpecBuilder.testCastTo(MAP(STRING().notNull(), STRING().notNull()))
+                        .fail(
+                                MAP(INT().nullable(), DOUBLE().nullable()),
+                                mapData(entry(123, null)),
+                                NullPointerException.class),
+                CastTestSpecBuilder.testCastTo(MULTISET(DOUBLE().notNull()))
+                        .fromCase(
+                                MULTISET(INT().nullable()),
+                                mapData(entry(1, 1)),
+                                mapData(entry(1d, 1))),
+                CastTestSpecBuilder.testCastTo(MULTISET(STRING().notNull()))
+                        .fromCase(
+                                MULTISET(INT().nullable()),
+                                mapData(entry(1, 1)),
+                                mapData(entry(fromString("1"), 1))),
+                CastTestSpecBuilder.testCastTo(MULTISET(FLOAT().nullable()))
+                        .fromCase(
+                                MULTISET(INT().nullable()),
+                                mapData(entry(null, 1)),
+                                mapData(entry(null, 1))),
+                CastTestSpecBuilder.testCastTo(MULTISET(STRING().notNull()))
+                        .fail(
+                                MULTISET(INT().nullable()),
+                                mapData(entry(null, 1)),
+                                NullPointerException.class),
                 CastTestSpecBuilder.testCastTo(
                                 ROW(BIGINT().notNull(), BIGINT(), STRING(), ARRAY(STRING())))
                         .fromCase(
@@ -1174,6 +1224,15 @@ class CastRulesTest {
                                                     fromString("b"),
                                                     fromString("c")
                                                 }))),
+                CastTestSpecBuilder.testCastTo(
+                                ROW(MAP(BIGINT().notNull(), STRING()), MULTISET(STRING())))
+                        .fromCase(
+                                ROW(MAP(INT().notNull(), INT()), MULTISET(TIMESTAMP())),
+                                GenericRowData.of(
+                                        mapData(entry(1, 2)), mapData(entry(TIMESTAMP, 1))),
+                                GenericRowData.of(
+                                        mapData(entry(1L, fromString("2"))),
+                                        mapData(entry(TIMESTAMP_STRING, 1)))),
                 CastTestSpecBuilder.testCastTo(MY_STRUCTURED_TYPE)
                         .fromCase(
                                 ROW(INT().notNull(), INT(), TIME(5), ARRAY(TIMESTAMP())),