Posted to commits@flink.apache.org by ja...@apache.org on 2022/08/16 13:10:45 UTC

[flink] branch master updated: [FLINK-27015][hive] Fix exception for casting timestamp to decimal in Hive dialect (#20571)

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 435dab64f61 [FLINK-27015][hive] Fix exception for casting timestamp to decimal in Hive dialect (#20571)
435dab64f61 is described below

commit 435dab64f6171dacdf90c759554dc3ef085d4ad5
Author: yuxia Luo <lu...@alumni.sjtu.edu.cn>
AuthorDate: Tue Aug 16 21:10:38 2022 +0800

    [FLINK-27015][hive] Fix exception for casting timestamp to decimal in Hive dialect (#20571)
---
 .../apache/flink/table/module/hive/HiveModule.java |  9 ++
 .../hive/udf/generic/HiveGenericUDFToDecimal.java  | 96 ++++++++++++++++++++++
 .../delegation/hive/HiveParserDMLHelper.java       | 23 ++++++
 .../hive/HiveParserRexNodeConverter.java           | 25 +++++-
 .../planner/delegation/hive/HiveParserUtils.java   |  6 ++
 .../connectors/hive/HiveDialectQueryITCase.java    | 51 ++++++++++++
 .../flink/table/module/hive/HiveModuleTest.java    |  4 +-
 7 files changed, 211 insertions(+), 3 deletions(-)
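
For context, this is the query shape the commit fixes: under the Hive dialect, casting a
TIMESTAMP value to DECIMAL previously failed because Flink's built-in CAST does not support
TIMESTAMP-to-NUMERIC conversions. A minimal repro sketch (illustrative only: a real run also
needs flink-connector-hive plus a HiveCatalog/HiveModule on the classpath, as set up in
HiveDialectQueryITCase below):

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.SqlDialect;
    import org.apache.flink.table.api.TableEnvironment;

    public class CastTimestampToDecimalRepro {
        public static void main(String[] args) {
            TableEnvironment tableEnv =
                    TableEnvironment.create(EnvironmentSettings.inBatchMode());
            tableEnv.getConfig().setSqlDialect(SqlDialect.HIVE);
            // Before this commit the cast below raised an exception; with it, the
            // planner rewrites the cast to the flink_hive_to_decimal function.
            tableEnv.executeSql(
                            "select cast(cast('2012-12-19 11:12:19.1234567' as timestamp)"
                                    + " as decimal(30,8))")
                    .print();
        }
    }
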

diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
index 66908b5df72..6e8a21f270f 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
@@ -29,6 +29,7 @@ import org.apache.flink.table.module.hive.udf.generic.GenericUDFLegacyGroupingID
 import org.apache.flink.table.module.hive.udf.generic.HiveGenericUDFArrayAccessStructField;
 import org.apache.flink.table.module.hive.udf.generic.HiveGenericUDFGrouping;
 import org.apache.flink.table.module.hive.udf.generic.HiveGenericUDFInternalInterval;
+import org.apache.flink.table.module.hive.udf.generic.HiveGenericUDFToDecimal;
 import org.apache.flink.util.StringUtils;
 
 import org.apache.hadoop.hive.ql.exec.FunctionInfo;
@@ -114,6 +115,7 @@ public class HiveModule implements Module {
             functionNames.add("grouping");
             functionNames.add(GenericUDFLegacyGroupingID.NAME);
             functionNames.add(HiveGenericUDFArrayAccessStructField.NAME);
+            functionNames.add(HiveGenericUDFToDecimal.NAME);
         }
         return functionNames;
     }
@@ -152,6 +154,13 @@ public class HiveModule implements Module {
                             name, HiveGenericUDFArrayAccessStructField.class.getName(), context));
         }
 
+        // We add a custom to_decimal function; see HiveGenericUDFToDecimal for details.
+        if (name.equalsIgnoreCase(HiveGenericUDFToDecimal.NAME)) {
+            return Optional.of(
+                    factory.createFunctionDefinitionFromHiveFunction(
+                            name, HiveGenericUDFToDecimal.class.getName(), context));
+        }
+
         Optional<FunctionInfo> info = hiveShim.getBuiltInFunctionInfo(name);
 
         return info.map(
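
Since the registration above makes the function visible by name, it should also be directly
callable wherever the HiveModule is loaded; a hypothetical, untested sketch (tableEnv as in the
repro above):

    // The NULL second argument only carries the target DECIMAL(30,8) type; its
    // value is ignored by the UDF, which reads the type from the ObjectInspector.
    static void callToDecimalDirectly(org.apache.flink.table.api.TableEnvironment tableEnv) {
        tableEnv.executeSql(
                "select flink_hive_to_decimal("
                        + "cast('2012-12-19 11:12:19.1234567' as timestamp), "
                        + "cast(null as decimal(30,8)))");
    }
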
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/udf/generic/HiveGenericUDFToDecimal.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/udf/generic/HiveGenericUDFToDecimal.java
new file mode 100644
index 00000000000..ec5719d6b20
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/udf/generic/HiveGenericUDFToDecimal.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.module.hive.udf.generic;
+
+import org.apache.flink.table.planner.delegation.hive.HiveParserUtils;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableHiveDecimalObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+
+/**
+ * Counterpart of Hive's org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDecimal. It removes the
+ * #setTypeInfo() method, which Hive uses to pass the target type, because Flink has no way to
+ * call that method. Instead, the target type is passed to the function as its second parameter.
+ */
+public class HiveGenericUDFToDecimal extends GenericUDF {
+
+    public static final String NAME = "flink_hive_to_decimal";
+
+    private transient PrimitiveObjectInspectorConverter.HiveDecimalConverter bdConverter;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
+        if (arguments.length != 2) {
+            throw new UDFArgumentLengthException(
+                    "The function flink_hive_to_decimal requires exactly two arguments, got "
+                            + arguments.length);
+        }
+        PrimitiveObjectInspector srcOI;
+        try {
+            srcOI = (PrimitiveObjectInspector) arguments[0];
+        } catch (ClassCastException e) {
+            throw new UDFArgumentException(
+                    "The function flink_hive_to_decimal takes only primitive types as first argument.");
+        }
+
+        HiveDecimalObjectInspector targetOI;
+        try {
+            targetOI = (HiveDecimalObjectInspector) arguments[1];
+        } catch (ClassCastException e) {
+            throw new UDFArgumentException(
+                    "The function flink_hive_to_decimal takes only decimal types as second argument.");
+        }
+
+        DecimalTypeInfo returnTypeInfo =
+                new DecimalTypeInfo(targetOI.precision(), targetOI.scale());
+
+        bdConverter =
+                new PrimitiveObjectInspectorConverter.HiveDecimalConverter(
+                        srcOI,
+                        (SettableHiveDecimalObjectInspector)
+                                PrimitiveObjectInspectorFactory
+                                        .getPrimitiveWritableConstantObjectInspector(
+                                                returnTypeInfo, null));
+
+        return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(returnTypeInfo);
+    }
+
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        Object o0 = arguments[0].get();
+        if (o0 == null) {
+            return null;
+        }
+        return bdConverter.convert(o0);
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return HiveParserUtils.getStandardDisplayString("flink_hive_to_decimal", children);
+    }
+}
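
What the conversion computes, in Hive's semantics, is the timestamp's epoch seconds with the
fractional second carried over as the decimal fraction. A plain-Java sketch of that arithmetic
(not the code path above, which delegates to Hive's HiveDecimalConverter; note the whole-second
value depends on the JVM time zone, and Hive 2 and Hive 3 differ in time-zone handling, which the
test below hides behind HiveShim):

    import java.math.BigDecimal;
    import java.sql.Timestamp;

    public class TimestampToDecimalArithmetic {
        public static void main(String[] args) {
            Timestamp ts = Timestamp.valueOf("2012-12-19 11:12:19.1234567");
            // Whole epoch seconds (zone-dependent) plus the nanosecond fraction.
            BigDecimal decimal =
                    BigDecimal.valueOf(ts.getTime() / 1000)
                            .add(BigDecimal.valueOf(ts.getNanos(), 9));
            System.out.println(decimal); // <epoch seconds>.123456700
        }
    }
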
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserDMLHelper.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserDMLHelper.java
index 06d0e38220b..280742dc1fa 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserDMLHelper.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserDMLHelper.java
@@ -31,6 +31,7 @@ import org.apache.flink.table.catalog.UnresolvedIdentifier;
 import org.apache.flink.table.catalog.hive.HiveCatalog;
 import org.apache.flink.table.catalog.hive.factories.HiveCatalogFactoryOptions;
 import org.apache.flink.table.factories.FactoryUtil;
+import org.apache.flink.table.module.hive.udf.generic.HiveGenericUDFToDecimal;
 import org.apache.flink.table.operations.Operation;
 import org.apache.flink.table.operations.QueryOperation;
 import org.apache.flink.table.operations.SinkModifyOperation;
@@ -61,6 +62,7 @@ import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlFunctionCategory;
+import org.apache.calcite.sql.SqlOperator;
 import org.apache.hadoop.hive.metastore.api.FieldSchema;
 import org.apache.hadoop.hive.ql.exec.FunctionInfo;
 import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
@@ -510,6 +512,27 @@ public class HiveParserDMLHelper {
         if (funcConverter == null) {
             return rexBuilder.makeCast(targetCalType, srcRex);
         }
+
+        if (HiveParserUtils.isFromTimeStampToDecimal(srcRex.getType(), targetCalType)) {
+            // Special case for casting timestamp to decimal: Flink doesn't support casting
+            // from TIMESTAMP type to NUMERIC type, so use the custom to_decimal function,
+            // which is consistent with Hive's behavior.
+            SqlOperator castOperator =
+                    HiveParserSqlFunctionConverter.getCalciteFn(
+                            HiveGenericUDFToDecimal.NAME,
+                            Arrays.asList(srcRex.getType(), targetCalType),
+                            targetCalType,
+                            false,
+                            funcConverter);
+            RexCall cast =
+                    (RexCall)
+                            rexBuilder.makeCall(
+                                    castOperator,
+                                    srcRex,
+                                    rexBuilder.makeNullLiteral(targetCalType));
+            return cast.accept(funcConverter);
+        }
+
         // hive implements CAST with UDFs
         String udfName = TypeInfoUtils.getBaseName(targetHiveType.getTypeName());
         FunctionInfo functionInfo;
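
A note on the rexBuilder.makeNullLiteral(targetCalType) argument above: the UDF never reads that
value, only its type. A small sketch showing that a constant-null decimal ObjectInspector still
carries precision and scale (it uses the same serde factory call as
HiveGenericUDFToDecimal#initialize above):

    import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;

    public class NullLiteralTypeCarrier {
        public static void main(String[] args) {
            HiveDecimalObjectInspector oi =
                    (HiveDecimalObjectInspector)
                            PrimitiveObjectInspectorFactory
                                    .getPrimitiveWritableConstantObjectInspector(
                                            new DecimalTypeInfo(30, 8), null);
            // The value is null, but precision/scale survive for initialize() to read.
            System.out.println(oi.precision() + "," + oi.scale()); // 30,8
        }
    }
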
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserRexNodeConverter.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserRexNodeConverter.java
index d7910ef38b7..17ea24a6e4a 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserRexNodeConverter.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserRexNodeConverter.java
@@ -21,6 +21,7 @@ package org.apache.flink.table.planner.delegation.hive;
 import org.apache.flink.table.catalog.hive.client.HiveShim;
 import org.apache.flink.table.catalog.hive.util.HiveReflectionUtils;
 import org.apache.flink.table.module.hive.udf.generic.HiveGenericUDFArrayAccessStructField;
+import org.apache.flink.table.module.hive.udf.generic.HiveGenericUDFToDecimal;
 import org.apache.flink.table.planner.delegation.hive.copy.HiveASTParseUtils;
 import org.apache.flink.table.planner.delegation.hive.copy.HiveParserExprNodeDescUtils;
 import org.apache.flink.table.planner.delegation.hive.copy.HiveParserExprNodeSubQueryDesc;
@@ -764,7 +765,29 @@ public class HiveParserRexNodeConverter {
             throws SemanticException {
         GenericUDF udf = func.getGenericUDF();
         if (isExplicitCast(udf) && childRexNodeLst != null && childRexNodeLst.size() == 1) {
-            // we cannot handle SettableUDF at the moment so we call calcite to do the cast in that
+            RelDataType targetType =
+                    HiveParserTypeConverter.convert(func.getTypeInfo(), cluster.getTypeFactory());
+            if (HiveParserUtils.isFromTimeStampToDecimal(
+                    childRexNodeLst.get(0).getType(), targetType)) {
+                // Special case for casting timestamp to decimal: Flink doesn't support casting
+                // from TIMESTAMP type to NUMERIC type, so use the custom to_decimal function,
+                // which is consistent with Hive's behavior.
+                SqlOperator castOperator =
+                        HiveParserSqlFunctionConverter.getCalciteFn(
+                                HiveGenericUDFToDecimal.NAME,
+                                Arrays.asList(childRexNodeLst.get(0).getType(), targetType),
+                                targetType,
+                                false,
+                                funcConverter);
+                return cluster.getRexBuilder()
+                        .makeCall(
+                                castOperator,
+                                childRexNodeLst.get(0),
+                                cluster.getRexBuilder().makeNullLiteral(targetType));
+            }
+
+            // we cannot handle SettableUDF at the moment so we call calcite to do the cast in that
             // case, otherwise we use hive functions to achieve better compatibility
             if (udf instanceof SettableUDF
                     || !funcConverter.hasOverloadedOp(
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java
index 00770732f48..54b8b8be1df 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java
@@ -88,6 +88,7 @@ import org.apache.calcite.sql.SqlSyntax;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.sql.validate.SqlNameMatchers;
 import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction;
@@ -1673,4 +1674,9 @@ public class HiveParserUtils {
                 type,
                 name);
     }
+
+    public static boolean isFromTimeStampToDecimal(RelDataType srcType, RelDataType targetType) {
+        return srcType.getSqlTypeName().equals(SqlTypeName.TIMESTAMP)
+                && targetType.getSqlTypeName().equals(SqlTypeName.DECIMAL);
+    }
 }
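
The guard matches exactly TIMESTAMP to DECIMAL by SqlTypeName, so other source types (e.g.
TIMESTAMP_WITH_LOCAL_TIME_ZONE) keep the default cast path. An illustrative check with a Calcite
type factory (assuming the Calcite API is on the classpath; JavaTypeFactoryImpl is just one
concrete factory):

    import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
    import org.apache.calcite.rel.type.RelDataType;
    import org.apache.calcite.rel.type.RelDataTypeFactory;
    import org.apache.calcite.sql.type.SqlTypeName;

    public class FromTimestampToDecimalGuard {
        public static void main(String[] args) {
            RelDataTypeFactory factory = new JavaTypeFactoryImpl();
            RelDataType src = factory.createSqlType(SqlTypeName.TIMESTAMP);
            RelDataType target = factory.createSqlType(SqlTypeName.DECIMAL, 30, 8);
            // Same comparison as HiveParserUtils.isFromTimeStampToDecimal.
            System.out.println(
                    src.getSqlTypeName().equals(SqlTypeName.TIMESTAMP)
                            && target.getSqlTypeName().equals(SqlTypeName.DECIMAL)); // true
        }
    }
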
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java
index 6ccd3b0ff86..890c000f090 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java
@@ -23,6 +23,9 @@ import org.apache.flink.table.api.SqlDialect;
 import org.apache.flink.table.api.TableEnvironment;
 import org.apache.flink.table.catalog.hive.HiveCatalog;
 import org.apache.flink.table.catalog.hive.HiveTestUtils;
+import org.apache.flink.table.catalog.hive.client.HiveShim;
+import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
+import org.apache.flink.table.functions.hive.conversion.HiveInspectors;
 import org.apache.flink.table.module.CoreModule;
 import org.apache.flink.table.module.hive.HiveModule;
 import org.apache.flink.table.planner.delegation.hive.HiveParserUtils;
@@ -30,6 +33,7 @@ import org.apache.flink.types.Row;
 import org.apache.flink.util.CollectionUtil;
 import org.apache.flink.util.FileUtils;
 
+import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
@@ -38,6 +42,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.TimestampObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.junit.BeforeClass;
 import org.junit.ComparisonFailure;
 import org.junit.Test;
@@ -47,6 +53,7 @@ import java.io.File;
 import java.io.FileReader;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.sql.Timestamp;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -754,6 +761,50 @@ public class HiveDialectQueryITCase {
         }
     }
 
+    @Test
+    public void testCastTimeStampToDecimal() throws Exception {
+        try {
+            String timestamp = "2012-12-19 11:12:19.1234567";
+            // timestamp behavior differs between Hive 2 and Hive 3, so use
+            // HiveShim in this test to hide the difference
+            HiveShim hiveShim = HiveShimLoader.loadHiveShim(HiveShimLoader.getHiveVersion());
+            Object hiveTimestamp = hiveShim.toHiveTimestamp(Timestamp.valueOf(timestamp));
+            TimestampObjectInspector timestampObjectInspector =
+                    (TimestampObjectInspector)
+                            HiveInspectors.getObjectInspector(TypeInfoFactory.timestampTypeInfo);
+
+            HiveDecimal expectTimeStampDecimal =
+                    timestampObjectInspector
+                            .getPrimitiveWritableObject(hiveTimestamp)
+                            .getHiveDecimal();
+
+            // test an explicit cast from timestamp to decimal
+            List<Row> results =
+                    CollectionUtil.iteratorToList(
+                            tableEnv.executeSql(
+                                            String.format(
+                                                    "select cast(cast('%s' as timestamp) as decimal(30,8))",
+                                                    timestamp))
+                                    .collect());
+            assertThat(results.toString())
+                    .isEqualTo(String.format("[+I[%s]]", expectTimeStampDecimal.toFormatString(8)));
+
+            // test inserting a timestamp column into a decimal column directly
+            tableEnv.executeSql("create table t1 (c1 DECIMAL(38,6))");
+            tableEnv.executeSql("create table t2 (c2 TIMESTAMP)");
+            tableEnv.executeSql(String.format("insert into t2 values('%s')", timestamp)).await();
+            tableEnv.executeSql("insert into t1 select * from t2").await();
+            results =
+                    CollectionUtil.iteratorToList(
+                            tableEnv.executeSql("select * from t1").collect());
+            assertThat(results.toString())
+                    .isEqualTo(String.format("[+I[%s]]", expectTimeStampDecimal.toFormatString(6)));
+        } finally {
+            tableEnv.executeSql("drop table t1");
+            tableEnv.executeSql("drop table t2");
+        }
+    }
+
     private void runQFile(File qfile) throws Exception {
         QTest qTest = extractQTest(qfile);
         for (int i = 0; i < qTest.statements.size(); i++) {
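
The assertions in the new test build their expected strings with HiveDecimal#toFormatString(int
scale), which renders the value with exactly that many fractional digits, padding with zeros when
the requested scale exceeds the stored fraction. A small sketch:

    import java.math.BigDecimal;

    import org.apache.hadoop.hive.common.type.HiveDecimal;

    public class ToFormatStringSketch {
        public static void main(String[] args) {
            HiveDecimal d = HiveDecimal.create(new BigDecimal("11.1234567"));
            // Pads the fraction with a trailing zero to reach scale 8.
            System.out.println(d.toFormatString(8)); // 11.12345670
        }
    }
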
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/module/hive/HiveModuleTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/module/hive/HiveModuleTest.java
index f7933d9dc5d..842e71eb82c 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/module/hive/HiveModuleTest.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/module/hive/HiveModuleTest.java
@@ -74,10 +74,10 @@ public class HiveModuleTest {
     private void verifyNumBuiltInFunctions(String hiveVersion, HiveModule hiveModule) {
         switch (hiveVersion) {
             case HIVE_VERSION_V2_3_9:
-                assertThat(hiveModule.listFunctions()).hasSize(275);
+                assertThat(hiveModule.listFunctions()).hasSize(276);
                 break;
             case HIVE_VERSION_V3_1_1:
-                assertThat(hiveModule.listFunctions()).hasSize(294);
+                assertThat(hiveModule.listFunctions()).hasSize(295);
                 break;
             default:
                 fail("Unknown test version " + hiveVersion);