You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ku...@apache.org on 2019/08/07 15:51:17 UTC

[flink] 01/05: [FLINK-13225][table-planner-blink] Fix type inference for hive udf

This is an automated email from the ASF dual-hosted git repository.

kurt pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit e08117f89ba012c37f70c4aad99d569d8a9ba2b6
Author: JingsongLi <lz...@aliyun.com>
AuthorDate: Sun Jul 28 20:16:14 2019 +0800

    [FLINK-13225][table-planner-blink] Fix type inference for hive udf
---
 .../catalog/FunctionCatalogOperatorTable.java      | 14 +++-
 .../planner/functions/utils/HiveFunctionUtils.java | 80 ++++++++++++++++++++
 .../functions/utils/HiveScalarSqlFunction.java     | 85 ++++++++++++++++++++++
 .../table/planner/codegen/ExprCodeGenerator.scala  | 18 ++++-
 .../functions/utils/ScalarSqlFunction.scala        |  9 ++-
 .../functions/utils/UserDefinedFunctionUtils.scala |  4 +
 6 files changed, 205 insertions(+), 5 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java
index 87a7bcb..ddf8f60 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java
@@ -26,6 +26,7 @@ import org.apache.flink.table.functions.FunctionDefinition;
 import org.apache.flink.table.functions.ScalarFunctionDefinition;
 import org.apache.flink.table.functions.TableFunctionDefinition;
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.functions.utils.HiveScalarSqlFunction;
 import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils;
 import org.apache.flink.table.types.utils.TypeConversions;
 
@@ -40,6 +41,8 @@ import org.apache.calcite.sql.validate.SqlNameMatcher;
 import java.util.List;
 import java.util.Optional;
 
+import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.isHiveFunc;
+
 /**
  * Thin adapter between {@link SqlOperatorTable} and {@link FunctionCatalog}.
  */
@@ -92,7 +95,16 @@ public class FunctionCatalogOperatorTable implements SqlOperatorTable {
 		if (functionDefinition instanceof AggregateFunctionDefinition) {
 			return convertAggregateFunction(name, (AggregateFunctionDefinition) functionDefinition);
 		} else if (functionDefinition instanceof ScalarFunctionDefinition) {
-			return convertScalarFunction(name, (ScalarFunctionDefinition) functionDefinition);
+			ScalarFunctionDefinition def = (ScalarFunctionDefinition) functionDefinition;
+			if (isHiveFunc(def.getScalarFunction())) {
+				return Optional.of(new HiveScalarSqlFunction(
+						name,
+						name,
+						def.getScalarFunction(),
+						typeFactory));
+			} else {
+				return convertScalarFunction(name, def);
+			}
 		} else if (functionDefinition instanceof TableFunctionDefinition &&
 				category != null &&
 				category.isTableFunction()) {
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveFunctionUtils.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveFunctionUtils.java
new file mode 100644
index 0000000..13a82cb
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveFunctionUtils.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.utils;
+
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.utils.TypeConversions;
+
+import org.apache.calcite.rel.type.RelDataType;
+
+import java.io.Serializable;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+
+import static org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType;
+
+/**
+ * Hack utils for hive function.
+ */
+public class HiveFunctionUtils {
+
+	public static boolean isHiveFunc(Object function) {
+		try {
+			getSetArgsMethod(function);
+			return true;
+		} catch (NoSuchMethodException e) {
+			return false;
+		}
+	}
+
+	private static Method getSetArgsMethod(Object function) throws NoSuchMethodException {
+		return function.getClass().getMethod(
+				"setArgumentTypesAndConstants", Object[].class, DataType[].class);
+
+	}
+
+	static Serializable invokeSetArgs(
+			Serializable function, Object[] constantArguments, LogicalType[] argTypes) {
+		try {
+			// See hive HiveFunction
+			Method method = getSetArgsMethod(function);
+			method.invoke(function, constantArguments, TypeConversions.fromLogicalToDataType(argTypes));
+			return function;
+		} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	static RelDataType invokeGetResultType(
+			Object function, Object[] constantArguments, LogicalType[] argTypes,
+			FlinkTypeFactory typeFactory) {
+		try {
+			// See hive HiveFunction
+			Method method = function.getClass()
+					.getMethod("getHiveResultType", Object[].class, DataType[].class);
+			DataType resultType = (DataType) method.invoke(
+					function, constantArguments, TypeConversions.fromLogicalToDataType(argTypes));
+			return typeFactory.createFieldTypeFromLogicalType(fromDataTypeToLogicalType(resultType));
+		} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+			throw new RuntimeException(e);
+		}
+	}
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveScalarSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveScalarSqlFunction.java
new file mode 100644
index 0000000..a44576a
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveScalarSqlFunction.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.utils;
+
+import org.apache.flink.table.functions.ScalarFunction;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.util.InstantiationUtil;
+
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+
+import java.io.IOException;
+import java.util.List;
+
+import scala.Some;
+
+import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.invokeGetResultType;
+import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.invokeSetArgs;
+import static org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getDefaultExternalClassForType;
+
+/**
+ * Hive {@link ScalarSqlFunction}.
+ * Override getFunction to clone function and invoke {@code HiveScalarFunction#setArgumentTypesAndConstants}.
+ * Override SqlReturnTypeInference to invoke {@code HiveScalarFunction#getHiveResultType} instead of
+ * {@code HiveScalarFunction#getResultType(Class[])}.
+ *
+ * @deprecated TODO hack code; its logic should be integrated into ScalarSqlFunction
+ */
+@Deprecated
+public class HiveScalarSqlFunction extends ScalarSqlFunction {
+
+	private final ScalarFunction function;
+
+	public HiveScalarSqlFunction(
+			String name, String displayName,
+			ScalarFunction function, FlinkTypeFactory typeFactory) {
+		super(name, displayName, function, typeFactory, new Some<>(createReturnTypeInference(function, typeFactory)));
+		this.function = function;
+	}
+
+	@Override
+	public ScalarFunction makeFunction(Object[] constantArguments, LogicalType[] argTypes) {
+		ScalarFunction clone;
+		try {
+			clone = InstantiationUtil.clone(function);
+		} catch (IOException | ClassNotFoundException e) {
+			throw new RuntimeException(e);
+		}
+		return (ScalarFunction) invokeSetArgs(clone, constantArguments, argTypes);
+	}
+
+	private static SqlReturnTypeInference createReturnTypeInference(
+			ScalarFunction function, FlinkTypeFactory typeFactory) {
+		return opBinding -> {
+			List<RelDataType> sqlTypes = opBinding.collectOperandTypes();
+			LogicalType[] parameters = UserDefinedFunctionUtils.getOperandTypeArray(opBinding);
+
+			Object[] constantArguments = new Object[sqlTypes.size()];
+			for (int i = 0; i < sqlTypes.size(); i++) {
+				if (!opBinding.isOperandNull(i, false) && opBinding.isOperandLiteral(i, false)) {
+					constantArguments[i] = opBinding.getOperandLiteralValue(
+							i, getDefaultExternalClassForType(parameters[i]));
+				}
+			}
+			return invokeGetResultType(function, constantArguments, parameters, typeFactory);
+		};
+	}
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
index e641708..7c55d73 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.codegen
 
 import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.table.api.TableException
+import org.apache.flink.table.dataformat.DataFormatConverters.{DataFormatConverter, getConverterForDataType}
 import org.apache.flink.table.dataformat._
 import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, RexAggLocalVariable, RexDistinctKeyVariable}
 import org.apache.flink.table.planner.codegen.CodeGenUtils.{requireTemporal, requireTimeInterval, _}
@@ -30,6 +31,7 @@ import org.apache.flink.table.planner.codegen.calls.{FunctionGenerator, ScalarFu
 import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._
 import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction
 import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction}
+import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
 import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
 import org.apache.flink.table.runtime.typeutils.TypeCheckUtils
 import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.{isNumeric, isTemporal, isTimeInterval}
@@ -730,7 +732,9 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
         GeneratedExpression(nullValue.resultTerm, nullValue.nullTerm, code, resultType)
 
       case ssf: ScalarSqlFunction =>
-        new ScalarFunctionCallGen(ssf.getScalarFunction).generate(ctx, operands, resultType)
+        new ScalarFunctionCallGen(
+          ssf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray))
+            .generate(ctx, operands, resultType)
 
       case tsf: TableSqlFunction =>
         new TableFunctionCallGen(tsf.getTableFunction).generate(ctx, operands, resultType)
@@ -757,4 +761,16 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
         throw new CodeGenException(s"Unsupported call: $explainCall")
     }
   }
+
+  def getOperandLiterals(operands: Seq[GeneratedExpression]): Array[AnyRef] = {
+    operands.map { expr =>
+      expr.literalValue match {
+        case None => null
+        case Some(literal) =>
+          getConverterForDataType(fromLogicalTypeToDataType(expr.resultType))
+              .asInstanceOf[DataFormatConverter[AnyRef, AnyRef]
+              ].toExternal(literal.asInstanceOf[AnyRef])
+      }
+    }.toArray
+  }
 }
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala
index 35b5b5d..159d4f1 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala
@@ -26,6 +26,7 @@ import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.{
 import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getDefaultExternalClassForType
 import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
 import org.apache.flink.table.runtime.types.TypeInfoLogicalTypeConverter.fromTypeInfoToLogicalType
+import org.apache.flink.table.types.logical.LogicalType
 
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.sql._
@@ -47,16 +48,18 @@ class ScalarSqlFunction(
     name: String,
     displayName: String,
     scalarFunction: ScalarFunction,
-    typeFactory: FlinkTypeFactory)
+    typeFactory: FlinkTypeFactory,
+    returnTypeInfer: Option[SqlReturnTypeInference] = None)
   extends SqlFunction(
     new SqlIdentifier(name, SqlParserPos.ZERO),
-    createReturnTypeInference(name, scalarFunction, typeFactory),
+    returnTypeInfer.getOrElse(createReturnTypeInference(name, scalarFunction, typeFactory)),
     createOperandTypeInference(name, scalarFunction, typeFactory),
     createOperandTypeChecker(name, scalarFunction),
     null,
     SqlFunctionCategory.USER_DEFINED_FUNCTION) {
 
-  def getScalarFunction: ScalarFunction = scalarFunction
+  def makeFunction(constants: Array[AnyRef], argTypes: Array[LogicalType]): ScalarFunction =
+    scalarFunction
 
   override def isDeterministic: Boolean = scalarFunction.isDeterministic
 
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala
index 1de25dd..e565a6e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala
@@ -754,6 +754,10 @@ object UserDefinedFunctionUtils {
     }
   }
 
+  def getOperandTypeArray(callBinding: SqlOperatorBinding): Array[LogicalType] = {
+    getOperandType(callBinding).toArray
+  }
+
   def getOperandType(callBinding: SqlOperatorBinding): Seq[LogicalType] = {
     val operandTypes = for (i <- 0 until callBinding.getOperandCount)
       yield callBinding.getOperandType(i)