You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2021/12/31 11:56:38 UTC

[flink] 02/02: [FLINK-25187][table-planner] Apply padding when CASTing to BINARY(<length>)

This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 49acb2723eda8ebd3fc59af19d4bc0abb9f1a318
Author: Marios Trivyzas <ma...@gmail.com>
AuthorDate: Tue Dec 21 16:42:23 2021 +0200

    [FLINK-25187][table-planner] Apply padding when CASTing to BINARY(<length>)
    
    Similarly to `CHAR(<length>)`, when casting to `BINARY(<length>)`,
    apply padding with 0 bytes to the right so that the resulting `byte[]`
    exactly matches the specified length.
    
    This closes #18162.
---
 .../functions/casting/BinaryToBinaryCastRule.java  | 41 +++++++++++++-----
 .../functions/casting/RawToBinaryCastRule.java     | 35 +++++++++------
 .../functions/casting/StringToBinaryCastRule.java  | 34 +++++++++------
 .../planner/functions/CastFunctionITCase.java      |  5 +++
 .../planner/functions/CastFunctionMiscITCase.java  | 10 +++++
 .../planner/functions/casting/CastRulesTest.java   | 50 +++++++++++++---------
 6 files changed, 116 insertions(+), 59 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BinaryToBinaryCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BinaryToBinaryCastRule.java
index 9887818..72fbcfc 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BinaryToBinaryCastRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BinaryToBinaryCastRule.java
@@ -18,8 +18,10 @@
 
 package org.apache.flink.table.planner.functions.casting;
 
+import org.apache.flink.table.types.logical.BinaryType;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeFamily;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 
 import java.util.Arrays;
@@ -47,7 +49,7 @@ class BinaryToBinaryCastRule extends AbstractExpressionCodeGeneratorCastRule<byt
     ((byte[])(inputValue))
 
     // new behavior
-    ((((byte[])(inputValue)).length <= 2) ? (((byte[])(inputValue))) : (java.util.Arrays.copyOfRange(((byte[])(inputValue)), 0, 2)))
+    ((((byte[])(inputValue)).length == 2) ? (((byte[])(inputValue))) : (java.util.Arrays.copyOf(((byte[])(inputValue)), 2)))
 
     */
 
@@ -60,18 +62,35 @@ class BinaryToBinaryCastRule extends AbstractExpressionCodeGeneratorCastRule<byt
         int inputLength = LogicalTypeChecks.getLength(inputLogicalType);
         int targetLength = LogicalTypeChecks.getLength(targetLogicalType);
 
-        if (context.legacyBehaviour()) {
+        if (context.legacyBehaviour()
+                || ((!couldTrim(targetLength)
+                                // Assume input length is respected by the source
+                                || (inputLength <= targetLength))
+                        && !couldPad(targetLogicalType, targetLength))) {
             return inputTerm;
         } else {
-            // Assume input length is respected by the source
-            if (inputLength <= targetLength) {
-                return inputTerm;
-            } else {
-                return ternaryOperator(
-                        arrayLength(inputTerm) + " <= " + targetLength,
-                        inputTerm,
-                        staticCall(Arrays.class, "copyOfRange", inputTerm, 0, targetLength));
-            }
+            return ternaryOperator(
+                    arrayLength(inputTerm) + " == " + targetLength,
+                    inputTerm,
+                    staticCall(Arrays.class, "copyOf", inputTerm, targetLength));
         }
     }
+
+    static boolean couldTrim(int targetLength) {
+        return targetLength < BinaryType.MAX_LENGTH;
+    }
+
+    static boolean couldPad(LogicalType targetType, int targetLength) {
+        return targetType.is(LogicalTypeRoot.BINARY) && targetLength < BinaryType.MAX_LENGTH;
+    }
+
+    static void trimOrPadByteArray(
+            String returnVariable,
+            int targetLength,
+            String deserializedByteArrayTerm,
+            CastRuleUtils.CodeWriter writer) {
+        writer.assignStmt(
+                returnVariable,
+                staticCall(Arrays.class, "copyOf", deserializedByteArrayTerm, targetLength));
+    }
 }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RawToBinaryCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RawToBinaryCastRule.java
index eecf384..8dc2111 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RawToBinaryCastRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RawToBinaryCastRule.java
@@ -23,12 +23,12 @@ import org.apache.flink.table.types.logical.LogicalTypeFamily;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 
-import java.util.Arrays;
-
 import static org.apache.flink.table.codesplit.CodeSplitUtil.newName;
+import static org.apache.flink.table.planner.functions.casting.BinaryToBinaryCastRule.couldPad;
+import static org.apache.flink.table.planner.functions.casting.BinaryToBinaryCastRule.couldTrim;
+import static org.apache.flink.table.planner.functions.casting.BinaryToBinaryCastRule.trimOrPadByteArray;
 import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.arrayLength;
 import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.methodCall;
-import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.staticCall;
 
 /** {@link LogicalTypeRoot#RAW} to {@link LogicalTypeFamily#BINARY_STRING} cast rule. */
 class RawToBinaryCastRule extends AbstractNullAwareCodeGeneratorCastRule<Object, byte[]> {
@@ -61,7 +61,7 @@ class RawToBinaryCastRule extends AbstractNullAwareCodeGeneratorCastRule<Object,
         if (deserializedByteArray$76.length <= 3) {
             result$291 = deserializedByteArray$76;
         } else {
-            result$291 = java.util.Arrays.copyOfRange(deserializedByteArray$76, 0, 3);
+            result$291 = java.util.Arrays.copyOf(deserializedByteArray$76, 3);
         }
         isNull$290 = result$291 == null;
     } else {
@@ -83,7 +83,8 @@ class RawToBinaryCastRule extends AbstractNullAwareCodeGeneratorCastRule<Object,
         final String typeSerializer = context.declareTypeSerializer(inputLogicalType);
         final String deserializedByteArrayTerm = newName("deserializedByteArray");
 
-        if (context.legacyBehaviour()) {
+        if (context.legacyBehaviour()
+                || !(couldTrim(targetLength) || (couldPad(targetLogicalType, targetLength)))) {
             return new CastRuleUtils.CodeWriter()
                     .assignStmt(returnVariable, methodCall(inputTerm, "toBytes", typeSerializer))
                     .toString();
@@ -95,18 +96,24 @@ class RawToBinaryCastRule extends AbstractNullAwareCodeGeneratorCastRule<Object,
                             methodCall(inputTerm, "toBytes", typeSerializer))
                     .ifStmt(
                             arrayLength(deserializedByteArrayTerm) + " <= " + targetLength,
-                            thenWriter ->
+                            thenWriter -> {
+                                if (couldPad(targetLogicalType, targetLength)) {
+                                    trimOrPadByteArray(
+                                            returnVariable,
+                                            targetLength,
+                                            deserializedByteArrayTerm,
+                                            thenWriter);
+                                } else {
                                     thenWriter.assignStmt(
-                                            returnVariable, deserializedByteArrayTerm),
+                                            returnVariable, deserializedByteArrayTerm);
+                                }
+                            },
                             elseWriter ->
-                                    elseWriter.assignStmt(
+                                    trimOrPadByteArray(
                                             returnVariable,
-                                            staticCall(
-                                                    Arrays.class,
-                                                    "copyOfRange",
-                                                    deserializedByteArrayTerm,
-                                                    0,
-                                                    targetLength)))
+                                            targetLength,
+                                            deserializedByteArrayTerm,
+                                            elseWriter))
                     .toString();
         }
     }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/StringToBinaryCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/StringToBinaryCastRule.java
index 680f0a8..a0c7bb0 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/StringToBinaryCastRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/StringToBinaryCastRule.java
@@ -23,12 +23,12 @@ import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeFamily;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 
-import java.util.Arrays;
-
 import static org.apache.flink.table.codesplit.CodeSplitUtil.newName;
+import static org.apache.flink.table.planner.functions.casting.BinaryToBinaryCastRule.couldPad;
+import static org.apache.flink.table.planner.functions.casting.BinaryToBinaryCastRule.couldTrim;
+import static org.apache.flink.table.planner.functions.casting.BinaryToBinaryCastRule.trimOrPadByteArray;
 import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.arrayLength;
 import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.methodCall;
-import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.staticCall;
 
 /**
  * {@link LogicalTypeFamily#CHARACTER_STRING} to {@link LogicalTypeFamily#BINARY_STRING} cast rule.
@@ -63,7 +63,7 @@ class StringToBinaryCastRule extends AbstractNullAwareCodeGeneratorCastRule<Stri
         if (byteArrayTerm$0.length <= 2) {
             result$1 = byteArrayTerm$0;
         } else {
-            result$1 = java.util.Arrays.copyOfRange(byteArrayTerm$0, 0, 2);
+            result$1 = java.util.Arrays.copyOf(byteArrayTerm$0, 2);
         }
         isNull$0 = result$1 == null;
     } else {
@@ -83,7 +83,8 @@ class StringToBinaryCastRule extends AbstractNullAwareCodeGeneratorCastRule<Stri
 
         final String byteArrayTerm = newName("byteArrayTerm");
 
-        if (context.legacyBehaviour()) {
+        if (context.legacyBehaviour()
+                || !(couldTrim(targetLength) || couldPad(targetLogicalType, targetLength))) {
             return new CastRuleUtils.CodeWriter()
                     .assignStmt(returnVariable, methodCall(inputTerm, "toBytes"))
                     .toString();
@@ -92,16 +93,23 @@ class StringToBinaryCastRule extends AbstractNullAwareCodeGeneratorCastRule<Stri
                     .declStmt(byte[].class, byteArrayTerm, methodCall(inputTerm, "toBytes"))
                     .ifStmt(
                             arrayLength(byteArrayTerm) + " <= " + targetLength,
-                            thenWriter -> thenWriter.assignStmt(returnVariable, byteArrayTerm),
+                            thenWriter -> {
+                                if (couldPad(targetLogicalType, targetLength)) {
+                                    trimOrPadByteArray(
+                                            returnVariable,
+                                            targetLength,
+                                            byteArrayTerm,
+                                            thenWriter);
+                                } else {
+                                    thenWriter.assignStmt(returnVariable, byteArrayTerm);
+                                }
+                            },
                             elseWriter ->
-                                    elseWriter.assignStmt(
+                                    trimOrPadByteArray(
                                             returnVariable,
-                                            staticCall(
-                                                    Arrays.class,
-                                                    "copyOfRange",
-                                                    byteArrayTerm,
-                                                    0,
-                                                    targetLength)))
+                                            targetLength,
+                                            byteArrayTerm,
+                                            elseWriter))
                     .toString();
         }
     }
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java
index b2440f0..42c042d 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java
@@ -301,12 +301,17 @@ public class CastFunctionITCase extends BuiltInFunctionTestBase {
                         .fromCase(CHAR(3), "foo", new byte[] {102, 111})
                         .fromCase(VARCHAR(5), "Flink", new byte[] {70, 108})
                         .fromCase(STRING(), "Apache", new byte[] {65, 112})
+                        .fromCase(VARCHAR(5), "f", new byte[] {102, 0})
+                        .fromCase(STRING(), "f", new byte[] {102, 0})
                         // Not supported - no fix
                         .fail(BOOLEAN(), true)
                         //
                         .fromCase(BINARY(2), DEFAULT_BINARY, DEFAULT_BINARY)
                         .fromCase(VARBINARY(3), DEFAULT_VARBINARY, new byte[] {0, 1})
                         .fromCase(BYTES(), DEFAULT_BYTES, new byte[] {0, 1})
+                        .fromCase(BINARY(1), new byte[] {111}, new byte[] {111, 0})
+                        .fromCase(VARBINARY(1), new byte[] {111}, new byte[] {111, 0})
+                        .fromCase(BYTES(), new byte[] {11}, new byte[] {11, 0})
                         // Not supported - no fix
                         .fail(DECIMAL(5, 3), 12.345)
                         .fail(TINYINT(), DEFAULT_NEGATIVE_TINY_INT)
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionMiscITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionMiscITCase.java
index 6273a0c..3ca2021 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionMiscITCase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionMiscITCase.java
@@ -169,6 +169,16 @@ public class CastFunctionMiscITCase extends BuiltInFunctionTestBase {
                                 BYTES()),
                 TestSpec.forFunction(
                                 BuiltInFunctionDefinitions.CAST,
+                                "cast from RAW(Integer) to BINARY(6)")
+                        .onFieldsWithData(123456)
+                        .andDataTypes(INT())
+                        .withFunction(IntegerToRaw.class)
+                        .testTableApiResult(
+                                call("IntegerToRaw", $("f0")).cast(BINARY(6)),
+                                new byte[] {0, 1, -30, 64, 0, 0},
+                                BINARY(6)),
+                TestSpec.forFunction(
+                                BuiltInFunctionDefinitions.CAST,
                                 "cast from RAW(UserPojo) to VARBINARY")
                         .onFieldsWithData(123456, "Flink")
                         .andDataTypes(INT(), STRING())
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java
index 5e7dfdf..301ab4c 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java
@@ -632,7 +632,7 @@ class CastRulesTest {
                                 fromString("(null,abc)"))
                         .fromCase(ROW(), GenericRowData.of(), fromString("()"))
                         .fromCase(
-                                RAW(LocalDateTime.class, new LocalDateTimeSerializer()),
+                                RAW(LocalDateTime.class, LocalDateTimeSerializer.INSTANCE),
                                 RawValueData.fromObject(
                                         LocalDateTime.parse("2020-11-11T18:08:01.123")),
                                 fromString("2020-11-11T18:08:01.123"))
@@ -749,19 +749,21 @@ class CastRulesTest {
                                 null,
                                 EMPTY_UTF8)
                         .fromCase(
-                                RAW(LocalDate.class, new LocalDateSerializer()),
+                                RAW(LocalDate.class, LocalDateSerializer.INSTANCE),
                                 RawValueData.fromObject(LocalDate.parse("2020-12-09")),
                                 fromString("2020-12-09  "))
                         .fromCaseLegacy(
-                                RAW(LocalDate.class, new LocalDateSerializer()),
+                                RAW(LocalDate.class, LocalDateSerializer.INSTANCE),
                                 RawValueData.fromObject(LocalDate.parse("2020-12-09")),
                                 fromString("2020-12-09"))
                         .fromCase(
-                                RAW(LocalDateTime.class, new LocalDateTimeSerializer()).nullable(),
+                                RAW(LocalDateTime.class, LocalDateTimeSerializer.INSTANCE)
+                                        .nullable(),
                                 null,
                                 EMPTY_UTF8)
                         .fromCaseLegacy(
-                                RAW(LocalDateTime.class, new LocalDateTimeSerializer()).nullable(),
+                                RAW(LocalDateTime.class, LocalDateTimeSerializer.INSTANCE)
+                                        .nullable(),
                                 null,
                                 EMPTY_UTF8),
                 CastTestSpecBuilder.testCastTo(VARCHAR(3))
@@ -906,21 +908,23 @@ class CastRulesTest {
                                 null,
                                 EMPTY_UTF8)
                         .fromCase(
-                                RAW(LocalDateTime.class, new LocalDateTimeSerializer()),
+                                RAW(LocalDateTime.class, LocalDateTimeSerializer.INSTANCE),
                                 RawValueData.fromObject(
                                         LocalDateTime.parse("2020-11-11T18:08:01.123")),
                                 fromString("202"))
                         .fromCaseLegacy(
-                                RAW(LocalDateTime.class, new LocalDateTimeSerializer()),
+                                RAW(LocalDateTime.class, LocalDateTimeSerializer.INSTANCE),
                                 RawValueData.fromObject(
                                         LocalDateTime.parse("2020-11-11T18:08:01.123")),
                                 fromString("2020-11-11T18:08:01.123"))
                         .fromCase(
-                                RAW(LocalDateTime.class, new LocalDateTimeSerializer()).nullable(),
+                                RAW(LocalDateTime.class, LocalDateTimeSerializer.INSTANCE)
+                                        .nullable(),
                                 null,
                                 EMPTY_UTF8)
                         .fromCaseLegacy(
-                                RAW(LocalDateTime.class, new LocalDateTimeSerializer()).nullable(),
+                                RAW(LocalDateTime.class, LocalDateTimeSerializer.INSTANCE)
+                                        .nullable(),
                                 null,
                                 EMPTY_UTF8),
                 CastTestSpecBuilder.testCastTo(BOOLEAN())
@@ -955,27 +959,31 @@ class CastRulesTest {
                         .fromCase(FLOAT(), 1.1234f, true)
                         .fromCase(DOUBLE(), 0.0d, false)
                         .fromCase(DOUBLE(), -0.12345678d, true),
-                CastTestSpecBuilder.testCastTo(BINARY(2))
+                CastTestSpecBuilder.testCastTo(BINARY(4))
+                        .fromCase(CHAR(3), fromString("foo"), new byte[] {102, 111, 111, 0})
                         .fromCaseLegacy(CHAR(3), fromString("foo"), new byte[] {102, 111, 111})
-                        .fromCase(CHAR(3), fromString("foo"), new byte[] {102, 111})
-                        .fromCase(CHAR(1), fromString("f"), new byte[] {102})
-                        .fromCase(CHAR(3), fromString("f"), new byte[] {102})
-                        .fromCase(VARCHAR(5), fromString("Flink"), new byte[] {70, 108})
+                        .fromCase(CHAR(1), fromString("f"), new byte[] {102, 0, 0, 0})
+                        .fromCaseLegacy(CHAR(1), fromString("f"), new byte[] {102})
+                        .fromCase(CHAR(3), fromString("f"), new byte[] {102, 0, 0, 0})
+                        .fromCaseLegacy(CHAR(3), fromString("f"), new byte[] {102})
+                        .fromCase(VARCHAR(5), fromString("Flink"), new byte[] {70, 108, 105, 110})
                         .fromCaseLegacy(
                                 VARCHAR(5),
                                 fromString("Flink"),
                                 new byte[] {70, 108, 105, 110, 107})
-                        .fromCase(STRING(), fromString("Apache"), new byte[] {65, 112})
+                        .fromCase(STRING(), fromString("Apache"), new byte[] {65, 112, 97, 99})
                         .fromCaseLegacy(
                                 STRING(),
                                 fromString("Apache"),
                                 new byte[] {65, 112, 97, 99, 104, 101})
-                        // We assume that the input length is respected, therefore, no trimming is
-                        // applied
-                        .fromCase(BINARY(2), new byte[] {1, 2, 3}, new byte[] {1, 2, 3})
-                        .fromCaseLegacy(BINARY(2), new byte[] {1, 2, 3}, new byte[] {1, 2, 3})
-                        .fromCase(VARBINARY(2), new byte[] {1, 2, 3}, new byte[] {1, 2, 3})
-                        .fromCaseLegacy(VARBINARY(2), new byte[] {1, 2, 3}, new byte[] {1, 2, 3}),
+                        .fromCase(STRING(), fromString("bar"), new byte[] {98, 97, 114, 0})
+                        .fromCaseLegacy(STRING(), fromString("bar"), new byte[] {98, 97, 114})
+                        .fromCase(BINARY(2), new byte[] {1, 2}, new byte[] {1, 2, 0, 0})
+                        .fromCaseLegacy(BINARY(2), new byte[] {1, 2}, new byte[] {1, 2})
+                        .fromCase(VARBINARY(3), new byte[] {1, 2, 3}, new byte[] {1, 2, 3, 0})
+                        .fromCaseLegacy(VARBINARY(3), new byte[] {1, 2, 3}, new byte[] {1, 2, 3})
+                        .fromCase(BYTES(), new byte[] {1, 2, 3}, new byte[] {1, 2, 3, 0})
+                        .fromCaseLegacy(BYTES(), new byte[] {1, 2, 3}, new byte[] {1, 2, 3}),
                 CastTestSpecBuilder.testCastTo(VARBINARY(4))
                         .fromCase(CHAR(3), fromString("foo"), new byte[] {102, 111, 111})
                         .fromCaseLegacy(