You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2021/05/17 14:38:55 UTC

[flink] 01/02: [FLINK-22666][table] Make structured type's fields more lenient during casting

This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch release-1.13
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 6bf53142f5a0d7445cb3468ee5b7e72809343e03
Author: Timo Walther <tw...@apache.org>
AuthorDate: Mon May 17 09:15:49 2021 +0200

    [FLINK-22666][table] Make structured type's fields more lenient during casting
    
    Compare children individually for anonymous structured types. This
    fixes issues with primitive fields and Scala case classes.
    
    This closes #15935.
---
 .../types/logical/utils/LogicalTypeCasts.java      |  58 +++++++++-
 .../table/types/LogicalTypeCastAvoidanceTest.java  |  42 +++++++-
 .../flink/table/types/LogicalTypeCastsTest.java    |  36 +++++++
 .../runtime/stream/sql/DataStreamScalaITCase.scala | 119 +++++++++++++++++++++
 4 files changed, 250 insertions(+), 5 deletions(-)

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
index 8c37f56..d9591c4 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
@@ -35,6 +35,8 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
 
 import static org.apache.flink.table.types.logical.LogicalTypeFamily.BINARY_STRING;
 import static org.apache.flink.table.types.logical.LogicalTypeFamily.CHARACTER_STRING;
@@ -315,8 +317,8 @@ public final class LogicalTypeCasts {
         } else if (hasFamily(sourceType, CONSTRUCTED) || hasFamily(targetType, CONSTRUCTED)) {
             return supportsConstructedCasting(sourceType, targetType, allowExplicit);
         } else if (sourceRoot == STRUCTURED_TYPE || targetRoot == STRUCTURED_TYPE) {
-            // inheritance is not supported yet, so structured type must be fully equal
-            return false;
+            return supportsStructuredCasting(
+                    sourceType, targetType, (s, t) -> supportsCasting(s, t, allowExplicit));
         } else if (sourceRoot == RAW || targetRoot == RAW) {
             // the two raw types are not equal (from initial invariant), casting is not possible
             return false;
@@ -334,6 +336,68 @@ public final class LogicalTypeCasts {
         return false;
     }
 
+    /**
+     * Checks whether two structured types are similar enough to be cast into each other.
+     *
+     * <p>Structured types with an object identifier are never considered similar here. Anonymous
+     * structured types are similar if they have the same implementation class (or both have none),
+     * the same attribute names in the same order, and if every pair of children at the same
+     * position satisfies {@code childPredicate}.
+     *
+     * @param sourceType type to cast from
+     * @param targetType type to cast to
+     * @param childPredicate check applied to each pair of children at the same position
+     * @return {@code true} if both types are structured types that are similar as described above
+     */
+    private static boolean supportsStructuredCasting(
+            LogicalType sourceType,
+            LogicalType targetType,
+            BiFunction<LogicalType, LogicalType, Boolean> childPredicate) {
+        final LogicalTypeRoot sourceRoot = sourceType.getTypeRoot();
+        final LogicalTypeRoot targetRoot = targetType.getTypeRoot();
+        if (sourceRoot != STRUCTURED_TYPE || targetRoot != STRUCTURED_TYPE) {
+            return false;
+        }
+        final StructuredType sourceStructuredType = (StructuredType) sourceType;
+        final StructuredType targetStructuredType = (StructuredType) targetType;
+        // non-anonymous structured types must be fully equal
+        if (sourceStructuredType.getObjectIdentifier().isPresent()
+                || targetStructuredType.getObjectIdentifier().isPresent()) {
+            return false;
+        }
+        // anonymous structured types are compared more leniently: it is enough if they provide
+        // similar fields, e.g. when a structured type derived from type information and one
+        // derived within the Table API differ slightly
+        final Class<?> sourceClass = sourceStructuredType.getImplementationClass().orElse(null);
+        final Class<?> targetClass = targetStructuredType.getImplementationClass().orElse(null);
+        if (sourceClass != targetClass) {
+            return false;
+        }
+        final List<String> sourceNames =
+                sourceStructuredType.getAttributes().stream()
+                        .map(StructuredType.StructuredAttribute::getName)
+                        .collect(Collectors.toList());
+        final List<String> targetNames =
+                targetStructuredType.getAttributes().stream()
+                        .map(StructuredType.StructuredAttribute::getName)
+                        .collect(Collectors.toList());
+        if (!sourceNames.equals(targetNames)) {
+            return false;
+        }
+        final List<LogicalType> sourceChildren = sourceType.getChildren();
+        final List<LogicalType> targetChildren = targetType.getChildren();
+        // children are compared by position, so both lists must have the same length
+        if (sourceChildren.size() != targetChildren.size()) {
+            return false;
+        }
+        for (int i = 0; i < sourceChildren.size(); i++) {
+            if (!childPredicate.apply(sourceChildren.get(i), targetChildren.get(i))) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     private static boolean supportsConstructedCasting(
             LogicalType sourceType, LogicalType targetType, boolean allowExplicit) {
         final LogicalTypeRoot sourceRoot = sourceType.getTypeRoot();
@@ -493,8 +540,11 @@ public final class LogicalTypeCasts {
                 final List<LogicalType> targetChildren = targetType.getChildren();
                 return supportsAvoidingCast(sourceChildren, targetChildren);
             }
-            // structured types should be equal (modulo nullability)
-            return sourceType.equals(targetType) || sourceType.copy(true).equals(targetType);
+            if (sourceType.equals(targetType) || sourceType.copy(true).equals(targetType)) {
+                return true;
+            }
+            return supportsStructuredCasting(
+                    sourceType, targetType, LogicalTypeCasts::supportsAvoidingCast);
         }
 
         @Override
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastAvoidanceTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastAvoidanceTest.java
index a994f22..470b720 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastAvoidanceTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastAvoidanceTest.java
@@ -230,7 +230,7 @@ public class LogicalTypeCastAvoidanceTest {
                         true
                     },
 
-                    // row and structure type
+                    // row and structured type
                     {
                         RowType.of(new IntType(), new VarCharType()),
                         createUserType("User2", new IntType(), new VarCharType()),
@@ -251,6 +251,46 @@ public class LogicalTypeCastAvoidanceTest {
                         RowType.of(new BigIntType(), new VarCharType()),
                         false
                     },
+
+                    // test slightly different children of anonymous structured types
+                    {
+                        StructuredType.newBuilder(Void.class)
+                                .attributes(
+                                        Arrays.asList(
+                                                new StructuredType.StructuredAttribute(
+                                                        "f1", new TimestampType()),
+                                                new StructuredType.StructuredAttribute(
+                                                        "diff", new TinyIntType(false))))
+                                .build(),
+                        StructuredType.newBuilder(Void.class)
+                                .attributes(
+                                        Arrays.asList(
+                                                new StructuredType.StructuredAttribute(
+                                                        "f1", new TimestampType()),
+                                                new StructuredType.StructuredAttribute(
+                                                        "diff", new TinyIntType(true))))
+                                .build(),
+                        true
+                    },
+                    {
+                        StructuredType.newBuilder(Void.class)
+                                .attributes(
+                                        Arrays.asList(
+                                                new StructuredType.StructuredAttribute(
+                                                        "f1", new TimestampType()),
+                                                new StructuredType.StructuredAttribute(
+                                                        "diff", new TinyIntType(true))))
+                                .build(),
+                        StructuredType.newBuilder(Void.class)
+                                .attributes(
+                                        Arrays.asList(
+                                                new StructuredType.StructuredAttribute(
+                                                        "f1", new TimestampType()),
+                                                new StructuredType.StructuredAttribute(
+                                                        "diff", new TinyIntType(false))))
+                                .build(),
+                        false
+                    }
                 });
     }
 
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
index e908e6c..64eade9 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
@@ -218,6 +218,42 @@ public class LogicalTypeCastsTest {
                         false,
                         true
                     },
+
+                    // test slightly different children of anonymous structured types
+                    {
+                        StructuredType.newBuilder(Void.class)
+                                .attributes(
+                                        Arrays.asList(
+                                                new StructuredAttribute("f1", new TimestampType()),
+                                                new StructuredAttribute(
+                                                        "diff", new TinyIntType(false))))
+                                .build(),
+                        StructuredType.newBuilder(Void.class)
+                                .attributes(
+                                        Arrays.asList(
+                                                new StructuredAttribute("f1", new TimestampType()),
+                                                new StructuredAttribute(
+                                                        "diff", new TinyIntType(true))))
+                                .build(),
+                        true,
+                        true
+                    },
+                    {
+                        StructuredType.newBuilder(Void.class)
+                                .attributes(
+                                        Arrays.asList(
+                                                new StructuredAttribute("f1", new TimestampType()),
+                                                new StructuredAttribute("diff", new IntType())))
+                                .build(),
+                        StructuredType.newBuilder(Void.class)
+                                .attributes(
+                                        Arrays.asList(
+                                                new StructuredAttribute("f1", new TimestampType()),
+                                                new StructuredAttribute("diff", new TinyIntType())))
+                                .build(),
+                        false,
+                        true
+                    }
                 });
     }
 
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DataStreamScalaITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DataStreamScalaITCase.scala
new file mode 100644
index 0000000..9cc1ce8
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DataStreamScalaITCase.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.stream.sql
+
+import org.apache.flink.streaming.api.scala.{CloseableIterator, DataStream, StreamExecutionEnvironment}
+import org.apache.flink.table.api.bridge.scala.StreamTableEnvironment
+import org.apache.flink.test.util.AbstractTestBase
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.{DataTypes, Table, TableResult}
+import org.apache.flink.table.catalog.{Column, ResolvedSchema}
+import org.apache.flink.table.planner.runtime.stream.sql.DataStreamScalaITCase.{ComplexCaseClass, ImmutableCaseClass}
+import org.apache.flink.types.Row
+import org.apache.flink.util.CollectionUtil
+
+import org.hamcrest.Matchers.containsInAnyOrder
+import org.junit.Assert.{assertEquals, assertThat}
+import org.junit.{Before, Test}
+
+import java.util
+import scala.collection.JavaConverters._
+
+/** Tests for connecting to the Scala [[DataStream]] API. */
+class DataStreamScalaITCase extends AbstractTestBase {
+
+  private var env: StreamExecutionEnvironment = _
+  private var tableEnv: StreamTableEnvironment = _
+
+  @Before
+  def before(): Unit = {
+    env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(4)
+    tableEnv = StreamTableEnvironment.create(env)
+  }
+
+  /** Converts case classes (one nesting another) to a table and back to a data stream. */
+  @Test
+  def testFromAndToDataStreamWithCaseClass(): Unit = {
+    val caseClasses = Array(
+      ComplexCaseClass(42, "hello", ImmutableCaseClass(42.0, b = true)),
+      ComplexCaseClass(42, null, ImmutableCaseClass(42.0, b = false)))
+    val dataStream = env.fromElements(caseClasses: _*)
+    val table = tableEnv.fromDataStream(dataStream)
+    testSchema(
+      table,
+      Column.physical("c", DataTypes.INT().notNull().bridgedTo(classOf[Int])),
+      Column.physical("a", DataTypes.STRING()),
+      Column.physical(
+        "p",
+        DataTypes.STRUCTURED(
+          classOf[ImmutableCaseClass],
+          DataTypes.FIELD("d", DataTypes.DOUBLE().notNull()), // serializer doesn't support null
+          DataTypes.FIELD(
+            "b",
+            DataTypes.BOOLEAN().notNull().bridgedTo(classOf[Boolean]))).notNull()))
+
+    testResult(
+      table.execute(),
+      Row.of(Int.box(42), "hello", ImmutableCaseClass(42.0, b = true)),
+      Row.of(Int.box(42), null, ImmutableCaseClass(42.0, b = false)))
+
+    // regression scenario of FLINK-22666 (primitive fields in Scala case classes)
+    val resultStream = tableEnv.toDataStream(table, classOf[ComplexCaseClass])
+    testResult(resultStream, caseClasses: _*)
+  }
+
+  // --------------------------------------------------------------------------------------------
+  // Helper methods
+  // --------------------------------------------------------------------------------------------
+
+  /** Asserts that the resolved schema of `table` consists of exactly `expectedColumns`. */
+  private def testSchema(table: Table, expectedColumns: Column*): Unit = {
+    assertEquals(ResolvedSchema.of(expectedColumns: _*), table.getResolvedSchema)
+  }
+
+  /** Collects all rows of `result` and asserts they equal `expectedRows` in any order. */
+  private def testResult(result: TableResult, expectedRows: Row*): Unit = {
+    val iterator = result.collect()
+    try {
+      val actualRows: util.List[Row] = CollectionUtil.iteratorToList(iterator)
+      assertThat(actualRows, containsInAnyOrder(expectedRows: _*))
+    } finally {
+      iterator.close()
+    }
+  }
+
+  /** Executes `dataStream` and asserts its output equals `expectedResult` in any order. */
+  private def testResult[T](dataStream: DataStream[T], expectedResult: T*): Unit = {
+    val iterator: CloseableIterator[T] = dataStream.executeAndCollect()
+    try {
+      val actualResult: util.List[T] = iterator.toList.asJava
+      assertThat(actualResult, containsInAnyOrder(expectedResult: _*))
+    } finally {
+      iterator.close()
+    }
+  }
+}
+
+object DataStreamScalaITCase {
+  /** Case class with mutable fields whose field `p` nests another case class. */
+  case class ComplexCaseClass(var c: Int, var a: String, var p: ImmutableCaseClass)
+  /** Case class with immutable fields: a boxed `d` and a primitive `b`. */
+  case class ImmutableCaseClass(d: java.lang.Double, b: Boolean)
+}