You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/12/16 15:46:54 UTC
[25/47] flink git commit: [FLINK-4704] [table] Refactor package
structure of flink-table.
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
new file mode 100644
index 0000000..13fe4c3
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -0,0 +1,1522 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.codegen
+
+import java.math.{BigDecimal => JBigDecimal}
+
+import org.apache.calcite.avatica.util.DateTimeUtils
+import org.apache.calcite.rex._
+import org.apache.calcite.sql.SqlOperator
+import org.apache.calcite.sql.`type`.SqlTypeName._
+import org.apache.calcite.sql.fun.SqlStdOperatorTable._
+import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, Function, MapFunction}
+import org.apache.flink.api.common.io.GenericInputFormat
+import org.apache.flink.api.common.typeinfo.{AtomicType, SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, RowTypeInfo, TupleTypeInfo}
+import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
+import org.apache.flink.table.api.TableConfig
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.codegen.CodeGenUtils._
+import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
+import org.apache.flink.table.codegen.Indenter.toISC
+import org.apache.flink.table.codegen.calls.FunctionGenerator
+import org.apache.flink.table.codegen.calls.ScalarOperators._
+import org.apache.flink.table.functions.UserDefinedFunction
+import org.apache.flink.table.typeutils.TypeConverter
+import org.apache.flink.table.typeutils.TypeCheckUtils._
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable
+
+/**
+ * A code generator for generating Flink [[org.apache.flink.api.common.functions.Function]]s.
+ *
+ * @param config configuration that determines runtime behavior
+ * @param nullableInput input(s) can be null.
+ * @param input1 type information about the first input of the Function
+ * @param input2 type information about the second input if the Function is binary
+ * @param input1PojoFieldMapping additional mapping information if input1 is a POJO (POJO types
+ * have no deterministic field order).
+ * @param input2PojoFieldMapping additional mapping information if input2 is a POJO (POJO types
+ * have no deterministic field order).
+ *
+ */
+class CodeGenerator(
+ config: TableConfig,
+ nullableInput: Boolean,
+ input1: TypeInformation[Any],
+ input2: Option[TypeInformation[Any]] = None,
+ input1PojoFieldMapping: Option[Array[Int]] = None,
+ input2PojoFieldMapping: Option[Array[Int]] = None)
+ extends RexVisitor[GeneratedExpression] {
+
+ // Fail fast on inconsistent configuration: when entire input rows may be null, the
+ // generated code depends on null checks, so they must be enabled in the TableConfig.
+ if (nullableInput && !config.getNullCheck) {
+ throw new CodeGenException("Null check must be enabled if entire rows can be null.")
+ }
+
+ // POJO types have no deterministic field order, so every POJO input needs an explicit
+ // field mapping. Non-POJO inputs (and an absent second input) need none.
+ input1 match {
+ case _: PojoTypeInfo[_] if input1PojoFieldMapping.isEmpty =>
+ throw new CodeGenException("No input mapping is specified for input1 of type POJO.")
+ case _ => // mapping present or not required
+ }
+
+ input2 match {
+ case Some(_: PojoTypeInfo[_]) if input2PojoFieldMapping.isEmpty =>
+ throw new CodeGenException("No input mapping is specified for input2 of type POJO.")
+ case _ => // mapping present, not required, or no second input at all
+ }
+
+ /**
+ * A code generator for generating unary Flink
+ * [[org.apache.flink.api.common.functions.Function]]s with one input.
+ *
+ * @param config configuration that determines runtime behavior
+ * @param nullableInput input(s) can be null.
+ * @param input type information about the input of the Function
+ * @param inputPojoFieldMapping additional mapping information necessary if input is a
+ * POJO (POJO types have no deterministic field order). A `null`
+ * array is treated as "no mapping", so a POJO input without a
+ * mapping is rejected by the constructor check with a
+ * [[CodeGenException]] instead of passing it as `Some(null)`.
+ */
+ def this(
+ config: TableConfig,
+ nullableInput: Boolean,
+ input: TypeInformation[Any],
+ inputPojoFieldMapping: Array[Int]) =
+ // Option(...) rather than Some(...): a null array (e.g. from Java callers) becomes None
+ this(config, nullableInput, input, None, Option(inputPojoFieldMapping))
+
+ /**
+ * A code generator for generating Flink input formats.
+ *
+ * The first input type is set to [[TypeConverter.DEFAULT_ROW_TYPE]], inputs are declared
+ * non-nullable, and neither a second input nor POJO field mappings are configured.
+ *
+ * @param config configuration that determines runtime behavior
+ */
+ def this(config: TableConfig) =
+ this(
+ config,
+ nullableInput = false,
+ input1 = TypeConverter.DEFAULT_ROW_TYPE,
+ input2 = None,
+ input1PojoFieldMapping = None)
+
+ // Reusable code fragments collected while generating expressions. Every collection is
+ // insertion-ordered so the generated code lists fragments in the order they were first
+ // requested; the sets also de-duplicate identical statements.
+
+ // set of member statements that will be added only once
+ // we use a LinkedHashSet to keep the insertion order
+ private val reusableMemberStatements = mutable.LinkedHashSet[String]()
+
+ // set of constructor statements that will be added only once
+ // we use a LinkedHashSet to keep the insertion order
+ private val reusableInitStatements = mutable.LinkedHashSet[String]()
+
+ // set of statements that will be added only once per record
+ // we use a LinkedHashSet to keep the insertion order
+ private val reusablePerRecordStatements = mutable.LinkedHashSet[String]()
+
+ // map of initial input unboxing expressions that will be added only once
+ // (inputTerm, index) -> expr
+ // we use a LinkedHashMap (instead of a hash-ordered map) so that, like the sets above,
+ // unboxing code is emitted in the order in which input fields are first accessed
+ private val reusableInputUnboxingExprs: mutable.Map[(String, Int), GeneratedExpression] =
+ mutable.LinkedHashMap[(String, Int), GeneratedExpression]()
+
+ /**
+ * Renders the collected member statements, one per line and terminated by a newline.
+ *
+ * @return code to be placed in the member area of the generated Function
+ * (e.g. member variables and their initialization)
+ */
+ def reuseMemberCode(): String = {
+ val memberStatements = reusableMemberStatements.mkString("\n")
+ memberStatements + "\n"
+ }
+
+ /**
+ * Renders the collected constructor statements, one per line and terminated by a newline.
+ *
+ * @return code to be placed in the constructor of the generated Function
+ */
+ def reuseInitCode(): String = {
+ val initStatements = reusableInitStatements.mkString("\n")
+ initStatements + "\n"
+ }
+
+ /**
+ * Renders the statements that run once per record, one per line and terminated by a
+ * newline.
+ *
+ * @return code to be placed in the SAM (single abstract method) of the generated Function
+ */
+ def reusePerRecordCode(): String = {
+ val perRecordStatements = reusablePerRecordStatements.mkString("\n")
+ perRecordStatements + "\n"
+ }
+
+ /**
+ * Renders the code of all registered input unboxing expressions, one per line and
+ * terminated by a newline.
+ *
+ * @return code that unboxes input fields into primitive variables plus their
+ * corresponding null flag variables
+ */
+ def reuseInputUnboxingCode(): String = {
+ val unboxingCode = reusableInputUnboxingExprs.values.map(expr => expr.code)
+ unboxingCode.mkString("\n") + "\n"
+ }
+
+ /**
+ * Name of the Java variable holding the first input after it has been cast (and possibly
+ * boxed) inside the generated function.
+ */
+ var input1Term: String = "in1"
+
+ /**
+ * Name of the Java variable holding the second input after it has been cast (and possibly
+ * boxed) inside the generated function.
+ */
+ var input2Term: String = "in2"
+
+ /**
+ * Name of the (cast) output collector parameter of the generated function.
+ */
+ var collectorTerm: String = "c"
+
+ /**
+ * Name of the output record, which may be declared in the member area (e.g. Row, Tuple).
+ */
+ var outRecordTerm: String = "out"
+
+ /**
+ * Whether the generated code has to check for null values, as configured in the
+ * [[TableConfig]].
+ */
+ def nullCheck: Boolean = {
+ config.getNullCheck
+ }
+
+ /**
+ * Translates a Calcite row expression into generated code by letting this generator visit
+ * it. As a side effect, objects and variables that can be reused are registered in the
+ * reusable code sections.
+ *
+ * @param rex Calcite row expression
+ * @return instance of GeneratedExpression
+ */
+ def generateExpression(rex: RexNode): GeneratedExpression = rex.accept(this)
+
+ /**
+ * Generates the source code of a [[org.apache.flink.api.common.functions.Function]] that
+ * can be passed to a Java compiler.
+ *
+ * Supported function classes are FlatMapFunction, MapFunction and FlatJoinFunction.
+ *
+ * @param name Class name of the Function. Need not be unique (it is passed to `newName`)
+ * but has to be a valid Java class identifier.
+ * @param clazz Flink Function to be generated.
+ * @param bodyCode code contents of the SAM (Single Abstract Method). Inputs, collector, or
+ * output record can be accessed via the given term methods.
+ * @param returnType expected return type
+ * @tparam T Flink Function to be generated.
+ * @return instance of GeneratedFunction
+ * @throws CodeGenException if `clazz` is not supported, or a FlatJoinFunction is requested
+ * without a second input
+ */
+ def generateFunction[T <: Function](
+ name: String,
+ clazz: Class[T],
+ bodyCode: String,
+ returnType: TypeInformation[Any])
+ : GeneratedFunction[T] = {
+ val funcName = newName(name)
+
+ // Janino does not support generics, so the SAM takes plain Objects and the first
+ // statements of its body cast every input to its concrete type.
+ // samHeader = (method signature, input cast statements)
+ val samHeader: (String, List[String]) = clazz match {
+ case c if c == classOf[FlatMapFunction[_, _]] =>
+ val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
+ (s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)",
+ List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
+
+ case c if c == classOf[MapFunction[_, _]] =>
+ val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
+ ("Object map(Object _in1)",
+ List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
+
+ case c if c == classOf[FlatJoinFunction[_, _, _]] =>
+ val leftTypeTerm = boxedTypeTermForTypeInfo(input1)
+ val rightTypeTerm = boxedTypeTermForTypeInfo(input2.getOrElse(
+ throw new CodeGenException("Input 2 for FlatJoinFunction should not be null")))
+ (s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)",
+ List(s"$leftTypeTerm $input1Term = ($leftTypeTerm) _in1;",
+ s"$rightTypeTerm $input2Term = ($rightTypeTerm) _in2;"))
+
+ case _ =>
+ // TODO more functions
+ throw new CodeGenException("Unsupported Function.")
+ }
+
+ val funcCode = j"""
+ public class $funcName
+ implements ${clazz.getCanonicalName} {
+
+ ${reuseMemberCode()}
+
+ public $funcName() throws Exception {
+ ${reuseInitCode()}
+ }
+
+ @Override
+ public ${samHeader._1} throws Exception {
+ ${samHeader._2.mkString("\n")}
+ ${reusePerRecordCode()}
+ ${reuseInputUnboxingCode()}
+ $bodyCode
+ }
+ }
+ """.stripMargin
+
+ GeneratedFunction(funcName, returnType, funcCode)
+ }
+
+ /**
+ * Generates the source code of a values input format that emits a fixed sequence of
+ * records and can be passed to a Java compiler.
+ *
+ * @param name Class name of the input format. Need not be unique (it is passed to
+ * `newName`) but has to be a valid Java class identifier.
+ * @param records code for creating records; the i-th entry runs on the i-th call of
+ * `nextRecord`, which then returns the output record ([[outRecordTerm]])
+ * @param returnType expected return type
+ * @tparam T type parameter of the generated [[GenericInputFormat]]
+ * @return instance of GeneratedFunction
+ */
+ def generateValuesInputFormat[T](
+ name: String,
+ records: Seq[String],
+ returnType: TypeInformation[Any])
+ : GeneratedFunction[GenericInputFormat[T]] = {
+ val funcName = newName(name)
+
+ // must run before reuseMemberCode() is evaluated in the template below
+ // (presumably it registers the output record as a reusable member — see its definition)
+ addReusableOutRecord(returnType)
+
+ // one switch case per record, keyed by the record's position
+ val recordCases = records.zipWithIndex.map { case (recordCode, recordIdx) =>
+ s"""
+ |case $recordIdx:
+ | $recordCode
+ |break;
+ """.stripMargin
+ }.mkString("\n")
+
+ val funcCode = j"""
+ public class $funcName extends ${classOf[GenericInputFormat[_]].getCanonicalName} {
+
+ private int nextIdx = 0;
+
+ ${reuseMemberCode()}
+
+ public $funcName() throws Exception {
+ ${reuseInitCode()}
+ }
+
+ @Override
+ public boolean reachedEnd() throws java.io.IOException {
+ return nextIdx >= ${records.length};
+ }
+
+ @Override
+ public Object nextRecord(Object reuse) {
+ switch (nextIdx) {
+ ${recordCases}
+ }
+ nextIdx++;
+ return $outRecordTerm;
+ }
+ }
+ """.stripMargin
+
+ GeneratedFunction[GenericInputFormat[T]](funcName, returnType, funcCode)
+ }
+
+ /**
+ * Generates an expression that converts the first input (and, if present, the second
+ * input) into the given type. Fields of the second input are appended after the fields of
+ * the first. If objects or variables can be reused, they will be added to reusable code
+ * sections internally. The evaluation result may be stored in the global result variable
+ * (see [[outRecordTerm]]).
+ *
+ * @param returnType conversion target type. Inputs and output must have the same arity.
+ * @param resultFieldNames result field names necessary for a mapping to POJO fields.
+ * @return instance of GeneratedExpression
+ */
+ def generateConverterResultExpression(
+ returnType: TypeInformation[_ <: Any],
+ resultFieldNames: Seq[String])
+ : GeneratedExpression = {
+ val firstInputFields = (0 until input1.getArity).map { fieldIdx =>
+ generateInputAccess(input1, input1Term, fieldIdx, input1PojoFieldMapping)
+ }
+
+ // an absent second input contributes no fields
+ val secondInputFields = input2.toList.flatMap { secondType =>
+ (0 until secondType.getArity).map { fieldIdx =>
+ generateInputAccess(secondType, input2Term, fieldIdx, input2PojoFieldMapping)
+ }
+ }
+
+ generateResultExpression(firstInputFields ++ secondInputFields, returnType, resultFieldNames)
+ }
+
+ /**
+ * Generates the field access expressions for a correlate: one per field of the left input
+ * (input1) and one per field of the table function's output (input2).
+ *
+ * @return (left input field accesses, table function field accesses)
+ * @throws CodeGenException if no type information for input2 is available
+ */
+ def generateCorrelateAccessExprs: (Seq[GeneratedExpression], Seq[GeneratedExpression]) = {
+ val leftFields = (0 until input1.getArity).map { fieldIdx =>
+ generateInputAccess(input1, input1Term, fieldIdx, input1PojoFieldMapping)
+ }
+
+ val functionType = input2.getOrElse(
+ throw new CodeGenException("Type information of input2 must not be null."))
+
+ // use generateFieldAccess instead of generateInputAccess to avoid the generated table
+ // function's field access code is put on the top of function body rather than
+ // the while loop
+ val functionFields = (0 until functionType.getArity).map { fieldIdx =>
+ generateFieldAccess(functionType, input2Term, fieldIdx, input2PojoFieldMapping)
+ }
+
+ (leftFields, functionFields)
+ }
+
+ /**
+ * Generates an expression from a sequence of RexNodes. Each node is translated first; the
+ * resulting field expressions are then assembled into the result type. If objects or
+ * variables can be reused, they will be added to reusable code sections internally. The
+ * evaluation result may be stored in the global result variable (see [[outRecordTerm]]).
+ *
+ * @param returnType conversion target type. Type must have the same arity as rexNodes.
+ * @param resultFieldNames result field names necessary for a mapping to POJO fields.
+ * @param rexNodes sequence of RexNode to be converted
+ * @return instance of GeneratedExpression
+ */
+ def generateResultExpression(
+ returnType: TypeInformation[_ <: Any],
+ resultFieldNames: Seq[String],
+ rexNodes: Seq[RexNode])
+ : GeneratedExpression =
+ generateResultExpression(
+ rexNodes.map(node => generateExpression(node)),
+ returnType,
+ resultFieldNames)
+
+ /**
+ * Generates an expression from a sequence of other expressions. If objects or variables can
+ * be reused, they will be added to reusable code sections internally. The evaluation result
+ * may be stored in the global result variable (see [[outRecordTerm]]).
+ *
+ * Supported result types are Row, POJO, Tuple, case class and atomic types. With null
+ * checking enabled, a null field is stored as null in Row and POJO results (POJO fields
+ * written via reflection get a primitive default if the field is primitive), whereas Tuple,
+ * case class and atomic results throw a NullPointerException at runtime.
+ *
+ * @param fieldExprs field expressions to be converted
+ * @param returnType conversion target type. Type must have the same arity as fieldExprs.
+ * @param resultFieldNames result field names necessary for a mapping to POJO fields.
+ * @return instance of GeneratedExpression
+ * @throws CodeGenException if arities or field types do not match, or the result type is
+ * not supported
+ */
+ def generateResultExpression(
+ fieldExprs: Seq[GeneratedExpression],
+ returnType: TypeInformation[_ <: Any],
+ resultFieldNames: Seq[String])
+ : GeneratedExpression = {
+ // initial type check
+ if (returnType.getArity != fieldExprs.length) {
+ throw new CodeGenException("Arity of result type does not match number of expressions.")
+ }
+ if (resultFieldNames.length != fieldExprs.length) {
+ throw new CodeGenException("Arity of result field names does not match number of " +
+ "expressions.")
+ }
+ // type check
+ // every field expression must produce exactly the type the result expects at that field
+ // (POJO fields are looked up by name, other composite fields by position)
+ returnType match {
+ case pt: PojoTypeInfo[_] =>
+ fieldExprs.zipWithIndex foreach {
+ case (fieldExpr, i) if fieldExpr.resultType != pt.getTypeAt(resultFieldNames(i)) =>
+ throw new CodeGenException("Incompatible types of expression and result type.")
+
+ case _ => // ok
+ }
+
+ case ct: CompositeType[_] =>
+ fieldExprs.zipWithIndex foreach {
+ case (fieldExpr, i) if fieldExpr.resultType != ct.getTypeAt(i) =>
+ throw new CodeGenException("Incompatible types of expression and result type.")
+ case _ => // ok
+ }
+
+ case at: AtomicType[_] if at != fieldExprs.head.resultType =>
+ throw new CodeGenException("Incompatible types of expression and result type.")
+
+ case _ => // ok
+ }
+
+ // Java type name of the result; used below for the case class constructor call
+ val returnTypeTerm = boxedTypeTermForTypeInfo(returnType)
+ // NOTE(review): generateOutputFieldBoxing is defined elsewhere; presumably it boxes
+ // primitive results so they can be stored in object-typed fields — confirm
+ val boxedFieldExprs = fieldExprs.map(generateOutputFieldBoxing)
+
+ // generate result expression
+ returnType match {
+ case ri: RowTypeInfo =>
+ // Row: every field is set by position on the output record
+ addReusableOutRecord(ri)
+ val resultSetters: String = boxedFieldExprs.zipWithIndex map {
+ case (fieldExpr, i) =>
+ if (nullCheck) {
+ s"""
+ |${fieldExpr.code}
+ |if (${fieldExpr.nullTerm}) {
+ | $outRecordTerm.setField($i, null);
+ |}
+ |else {
+ | $outRecordTerm.setField($i, ${fieldExpr.resultTerm});
+ |}
+ |""".stripMargin
+ }
+ else {
+ s"""
+ |${fieldExpr.code}
+ |$outRecordTerm.setField($i, ${fieldExpr.resultTerm});
+ |""".stripMargin
+ }
+ } mkString "\n"
+
+ GeneratedExpression(outRecordTerm, "false", resultSetters, returnType)
+
+ case pt: PojoTypeInfo[_] =>
+ // POJO: fields are matched by name; the accessor decides between reflective writes
+ // and direct field assignment
+ addReusableOutRecord(pt)
+ val resultSetters: String = boxedFieldExprs.zip(resultFieldNames) map {
+ case (fieldExpr, fieldName) =>
+ val accessor = getFieldAccessor(pt.getTypeClass, fieldName)
+
+ accessor match {
+ // Reflective access of primitives/Objects
+ case ObjectPrivateFieldAccessor(field) =>
+ val fieldTerm = addReusablePrivateFieldAccess(pt.getTypeClass, fieldName)
+
+ // a primitive field cannot hold null, so a null result writes its default value
+ val defaultIfNull = if (isFieldPrimitive(field)) {
+ primitiveDefaultValue(fieldExpr.resultType)
+ } else {
+ "null"
+ }
+
+ if (nullCheck) {
+ s"""
+ |${fieldExpr.code}
+ |if (${fieldExpr.nullTerm}) {
+ | ${reflectiveFieldWriteAccess(
+ fieldTerm,
+ field,
+ outRecordTerm,
+ defaultIfNull)};
+ |}
+ |else {
+ | ${reflectiveFieldWriteAccess(
+ fieldTerm,
+ field,
+ outRecordTerm,
+ fieldExpr.resultTerm)};
+ |}
+ |""".stripMargin
+ }
+ else {
+ s"""
+ |${fieldExpr.code}
+ |${reflectiveFieldWriteAccess(
+ fieldTerm,
+ field,
+ outRecordTerm,
+ fieldExpr.resultTerm)};
+ |""".stripMargin
+ }
+
+ // NOTE(review): unlike the reflective branch above, a null result is assigned as a
+ // literal `null` here; if the POJO field is primitive, the generated Java would not
+ // compile — confirm whether primitive fields can reach this branch.
+ // primitive or Object field access (implicit boxing)
+ case _ =>
+ if (nullCheck) {
+ s"""
+ |${fieldExpr.code}
+ |if (${fieldExpr.nullTerm}) {
+ | $outRecordTerm.$fieldName = null;
+ |}
+ |else {
+ | $outRecordTerm.$fieldName = ${fieldExpr.resultTerm};
+ |}
+ |""".stripMargin
+ }
+ else {
+ s"""
+ |${fieldExpr.code}
+ |$outRecordTerm.$fieldName = ${fieldExpr.resultTerm};
+ |""".stripMargin
+ }
+ }
+ } mkString "\n"
+
+ GeneratedExpression(outRecordTerm, "false", resultSetters, returnType)
+
+ case tup: TupleTypeInfo[_] =>
+ // Tuple: fields are written as f0, f1, ...; a null result raises an NPE at runtime
+ addReusableOutRecord(tup)
+ val resultSetters: String = boxedFieldExprs.zipWithIndex map {
+ case (fieldExpr, i) =>
+ val fieldName = "f" + i
+ if (nullCheck) {
+ s"""
+ |${fieldExpr.code}
+ |if (${fieldExpr.nullTerm}) {
+ | throw new NullPointerException("Null result cannot be stored in a Tuple.");
+ |}
+ |else {
+ | $outRecordTerm.$fieldName = ${fieldExpr.resultTerm};
+ |}
+ |""".stripMargin
+ }
+ else {
+ s"""
+ |${fieldExpr.code}
+ |$outRecordTerm.$fieldName = ${fieldExpr.resultTerm};
+ |""".stripMargin
+ }
+ } mkString "\n"
+
+ GeneratedExpression(outRecordTerm, "false", resultSetters, returnType)
+
+ case cc: CaseClassTypeInfo[_] =>
+ // case class: all fields are evaluated first, then a new instance is created through
+ // its constructor (no reusable output record is used)
+ val fieldCodes: String = boxedFieldExprs.map(_.code).mkString("\n")
+ val constructorParams: String = boxedFieldExprs.map(_.resultTerm).mkString(", ")
+ val resultTerm = newName(outRecordTerm)
+
+ val nullCheckCode = if (nullCheck) {
+ boxedFieldExprs map { (fieldExpr) =>
+ s"""
+ |if (${fieldExpr.nullTerm}) {
+ | throw new NullPointerException("Null result cannot be stored in a Case Class.");
+ |}
+ |""".stripMargin
+ } mkString "\n"
+ } else {
+ ""
+ }
+
+ val resultCode =
+ s"""
+ |$fieldCodes
+ |$nullCheckCode
+ |$returnTypeTerm $resultTerm = new $returnTypeTerm($constructorParams);
+ |""".stripMargin
+
+ GeneratedExpression(resultTerm, "false", resultCode, returnType)
+
+ case a: AtomicType[_] =>
+ // atomic: the single (boxed) field expression is the result itself
+ val fieldExpr = boxedFieldExprs.head
+ val nullCheckCode = if (nullCheck) {
+ s"""
+ |if (${fieldExpr.nullTerm}) {
+ | throw new NullPointerException("Null result cannot be used for atomic types.");
+ |}
+ |""".stripMargin
+ } else {
+ ""
+ }
+ val resultCode =
+ s"""
+ |${fieldExpr.code}
+ |$nullCheckCode
+ |""".stripMargin
+
+ GeneratedExpression(fieldExpr.resultTerm, "false", resultCode, returnType)
+
+ case _ =>
+ throw new CodeGenException(s"Unsupported result type: $returnType")
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------------
+ // RexVisitor methods
+ // ----------------------------------------------------------------------------------------------
+
+ /**
+ * Translates a reference to an input field. Indices below the arity of input1 address
+ * input1; all other indices address input2, shifted by the arity of input1.
+ *
+ * @throws CodeGenException if the index addresses input2 but there is no second input
+ */
+ override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = {
+ val refIndex = inputRef.getIndex
+ val input1Arity = input1.getArity
+ // The branch taken (not a comparison of the term names) decides which input is
+ // accessed and whether the index is shifted, so this stays correct even if
+ // input1Term and input2Term happen to be set to the same name.
+ if (refIndex < input1Arity) {
+ generateInputAccess(input1, input1Term, refIndex, input1PojoFieldMapping)
+ } else {
+ val secondInput = input2.getOrElse(throw new CodeGenException("Invalid input access."))
+ generateInputAccess(
+ secondInput,
+ input2Term,
+ refIndex - input1Arity,
+ input2PojoFieldMapping)
+ }
+ }
+
+ override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = {
+ // translate the expression whose field is accessed
+ val reference = rexFieldAccess.getReferenceExpr.accept(this)
+ val fieldIndex = rexFieldAccess.getField.getIndex
+ // NOTE(review): input1's POJO field mapping is passed regardless of which expression is
+ // referenced — confirm this is intended for nested POJO fields.
+ val access = generateFieldAccess(
+ reference.resultType,
+ reference.resultTerm,
+ fieldIndex,
+ input1PojoFieldMapping)
+
+ val resultTerm = newName("result")
+ val nullTerm = newName("isNull")
+ val resultTypeTerm = primitiveTypeTermForTypeInfo(access.resultType)
+ val defaultValue = primitiveDefaultValue(access.resultType)
+
+ // without null checking the field is read unconditionally and no null flag is declared
+ val resultCode = if (!nullCheck) {
+ s"""
+ |${reference.code}
+ |${access.code}
+ |$resultTypeTerm $resultTerm = ${access.resultTerm};
+ |""".stripMargin
+ } else {
+ // a null reference yields a null field; the field is only read otherwise
+ s"""
+ |${reference.code}
+ |$resultTypeTerm $resultTerm;
+ |boolean $nullTerm;
+ |if (${reference.nullTerm}) {
+ | $resultTerm = $defaultValue;
+ | $nullTerm = true;
+ |}
+ |else {
+ | ${access.code}
+ | $resultTerm = ${access.resultTerm};
+ | $nullTerm = ${access.nullTerm};
+ |}
+ |""".stripMargin
+ }
+
+ GeneratedExpression(resultTerm, nullTerm, resultCode, access.resultType)
+ }
+
+ /**
+ * Translates a literal into a constant Java expression. Null literals become typed null
+ * expressions; integral and interval literals are checked to fit their target type, and
+ * character literals are escaped before they are embedded in the generated code.
+ *
+ * @throws CodeGenException if a value does not fit its type or the type is not supported
+ */
+ override def visitLiteral(literal: RexLiteral): GeneratedExpression = {
+ val resultType = FlinkTypeFactory.toTypeInfo(literal.getType)
+ val value = literal.getValue3
+
+ if (value == null) {
+ // null value with type
+ generateNullLiteral(resultType)
+ } else {
+ // non-null values
+ literal.getType.getSqlTypeName match {
+
+ case BOOLEAN =>
+ generateNonNullLiteral(resultType, value.toString)
+
+ case TINYINT =>
+ val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
+ if (decimal.isValidByte) {
+ generateNonNullLiteral(resultType, decimal.byteValue().toString)
+ }
+ else {
+ throw new CodeGenException("Decimal can not be converted to byte.")
+ }
+
+ case SMALLINT =>
+ val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
+ if (decimal.isValidShort) {
+ generateNonNullLiteral(resultType, decimal.shortValue().toString)
+ }
+ else {
+ throw new CodeGenException("Decimal can not be converted to short.")
+ }
+
+ case INTEGER =>
+ val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
+ if (decimal.isValidInt) {
+ generateNonNullLiteral(resultType, decimal.intValue().toString)
+ }
+ else {
+ throw new CodeGenException("Decimal can not be converted to integer.")
+ }
+
+ case BIGINT =>
+ val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
+ if (decimal.isValidLong) {
+ generateNonNullLiteral(resultType, decimal.longValue().toString + "L")
+ }
+ else {
+ throw new CodeGenException("Decimal can not be converted to long.")
+ }
+
+ case FLOAT =>
+ val floatValue = value.asInstanceOf[JBigDecimal].floatValue()
+ floatValue match {
+ // NaN is never equal to itself, so a stable identifier pattern (`case Float.NaN`)
+ // can never match it; it has to be tested explicitly
+ case f if f.isNaN => generateNonNullLiteral(resultType, "java.lang.Float.NaN")
+ case Float.NegativeInfinity =>
+ generateNonNullLiteral(resultType, "java.lang.Float.NEGATIVE_INFINITY")
+ case Float.PositiveInfinity =>
+ generateNonNullLiteral(resultType, "java.lang.Float.POSITIVE_INFINITY")
+ case _ => generateNonNullLiteral(resultType, floatValue.toString + "f")
+ }
+
+ case DOUBLE =>
+ val doubleValue = value.asInstanceOf[JBigDecimal].doubleValue()
+ doubleValue match {
+ // see FLOAT: NaN cannot be matched by a stable identifier pattern
+ case d if d.isNaN => generateNonNullLiteral(resultType, "java.lang.Double.NaN")
+ case Double.NegativeInfinity =>
+ generateNonNullLiteral(resultType, "java.lang.Double.NEGATIVE_INFINITY")
+ case Double.PositiveInfinity =>
+ generateNonNullLiteral(resultType, "java.lang.Double.POSITIVE_INFINITY")
+ case _ => generateNonNullLiteral(resultType, doubleValue.toString + "d")
+ }
+
+ case DECIMAL =>
+ val decimalField = addReusableDecimal(value.asInstanceOf[JBigDecimal])
+ generateNonNullLiteral(resultType, decimalField)
+
+ case VARCHAR | CHAR =>
+ // the value is embedded in generated Java source, so quotes, backslashes and control
+ // characters must be escaped to keep both the value and the generated code intact
+ generateNonNullLiteral(
+ resultType,
+ "\"" + escapeJavaStringLiteral(value.toString) + "\"")
+
+ case SYMBOL =>
+ generateSymbol(value.asInstanceOf[Enum[_]])
+
+ case DATE =>
+ generateNonNullLiteral(resultType, value.toString)
+
+ case TIME =>
+ generateNonNullLiteral(resultType, value.toString)
+
+ case TIMESTAMP =>
+ generateNonNullLiteral(resultType, value.toString + "L")
+
+ case typeName if YEAR_INTERVAL_TYPES.contains(typeName) =>
+ val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
+ if (decimal.isValidInt) {
+ generateNonNullLiteral(resultType, decimal.intValue().toString)
+ } else {
+ throw new CodeGenException("Decimal can not be converted to interval of months.")
+ }
+
+ case typeName if DAY_INTERVAL_TYPES.contains(typeName) =>
+ val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
+ if (decimal.isValidLong) {
+ generateNonNullLiteral(resultType, decimal.longValue().toString + "L")
+ } else {
+ throw new CodeGenException("Decimal can not be converted to interval of milliseconds.")
+ }
+
+ case t@_ =>
+ throw new CodeGenException(s"Type not supported: $t")
+ }
+ }
+ }
+
+ /**
+ * Escapes a string so that it can be placed between double quotes in generated Java code
+ * and evaluates to exactly the original characters.
+ *
+ * Backslashes and double quotes are escaped, common control characters use their Java
+ * escape sequences, and all other ISO control characters are written as octal escapes.
+ * Unicode escapes are deliberately avoided because Java translates them before lexing
+ * (an escaped line feed would terminate the literal).
+ */
+ private def escapeJavaStringLiteral(str: String): String = {
+ val escaped = new StringBuilder(str.length)
+ str.foreach {
+ case '"' => escaped.append("\\\"")
+ case '\\' => escaped.append("\\\\")
+ case '\n' => escaped.append("\\n")
+ case '\r' => escaped.append("\\r")
+ case '\t' => escaped.append("\\t")
+ case '\b' => escaped.append("\\b")
+ case '\f' => escaped.append("\\f")
+ case c if Character.isISOControl(c) => escaped.append("\\%03o".format(c.toInt))
+ case c => escaped.append(c)
+ }
+ escaped.toString
+ }
+
+ /**
+ * A correlation variable is translated to the first input itself: its term, marked as
+ * never null and requiring no additional code.
+ */
+ override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression =
+ GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1)
+
+ // The following Rex node kinds are not supported; translating them fails code generation
+ // with a CodeGenException.
+
+ override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = {
+ throw new CodeGenException("Local variables are not supported yet.")
+ }
+
+ override def visitRangeRef(rangeRef: RexRangeRef): GeneratedExpression = {
+ throw new CodeGenException("Range references are not supported yet.")
+ }
+
+ override def visitDynamicParam(dynamicParam: RexDynamicParam): GeneratedExpression = {
+ throw new CodeGenException("Dynamic parameter references are not supported yet.")
+ }
+
+ override def visitCall(call: RexCall): GeneratedExpression = {
+ // all operands are translated eagerly (in order) before the operator is dispatched
+ val operands = call.getOperands.map(_.accept(this))
+ val resultType = FlinkTypeFactory.toTypeInfo(call.getType)
+
+ // Helpers for the recurring operator shapes. Each one validates its operands in order
+ // (left before right) and then delegates to the matching generator.
+
+ // binary arithmetic on numeric operands
+ def genNumericBinary(operator: String): GeneratedExpression = {
+ val left = operands.head
+ val right = operands(1)
+ requireNumeric(left)
+ requireNumeric(right)
+ generateArithmeticOperator(operator, nullCheck, resultType, left, right)
+ }
+
+ // addition / subtraction on temporal operands
+ def genTemporalBinary(plus: Boolean): GeneratedExpression = {
+ val left = operands.head
+ val right = operands(1)
+ requireTemporal(left)
+ requireTemporal(right)
+ generateTemporalPlusMinus(plus, nullCheck, left, right)
+ }
+
+ // unary sign operator on a numeric operand
+ def genNumericUnary(operator: String): GeneratedExpression = {
+ val operand = operands.head
+ requireNumeric(operand)
+ generateUnaryArithmeticOperator(operator, nullCheck, resultType, operand)
+ }
+
+ // unary sign operator on a time interval operand
+ def genIntervalUnary(plus: Boolean): GeneratedExpression = {
+ val operand = operands.head
+ requireTimeInterval(operand)
+ generateUnaryIntervalPlusMinus(plus, nullCheck, operand)
+ }
+
+ // ordering comparison on comparable operands
+ def genComparison(operator: String): GeneratedExpression = {
+ val left = operands.head
+ val right = operands(1)
+ requireComparable(left)
+ requireComparable(right)
+ generateComparison(operator, nullCheck, left, right)
+ }
+
+ // unary operator on a boolean operand
+ def genBooleanTest(
+ generator: GeneratedExpression => GeneratedExpression): GeneratedExpression = {
+ val operand = operands.head
+ requireBoolean(operand)
+ generator(operand)
+ }
+
+ call.getOperator match {
+ // arithmetic
+ case PLUS if isNumeric(resultType) => genNumericBinary("+")
+ case PLUS | DATETIME_PLUS if isTemporal(resultType) => genTemporalBinary(plus = true)
+ case MINUS if isNumeric(resultType) => genNumericBinary("-")
+ case MINUS | MINUS_DATE if isTemporal(resultType) => genTemporalBinary(plus = false)
+ case MULTIPLY if isNumeric(resultType) => genNumericBinary("*")
+ case DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) => genNumericBinary("/")
+ case MOD if isNumeric(resultType) => genNumericBinary("%")
+ case UNARY_MINUS if isNumeric(resultType) => genNumericUnary("-")
+ case UNARY_MINUS if isTimeInterval(resultType) => genIntervalUnary(plus = false)
+ case UNARY_PLUS if isNumeric(resultType) => genNumericUnary("+")
+ case UNARY_PLUS if isTimeInterval(resultType) => genIntervalUnary(plus = true)
+
+ // comparison
+ case EQUALS => generateEquals(nullCheck, operands.head, operands(1))
+ case NOT_EQUALS => generateNotEquals(nullCheck, operands.head, operands(1))
+ case GREATER_THAN => genComparison(">")
+ case GREATER_THAN_OR_EQUAL => genComparison(">=")
+ case LESS_THAN => genComparison("<")
+ case LESS_THAN_OR_EQUAL => genComparison("<=")
+ case IS_NULL => generateIsNull(nullCheck, operands.head)
+ case IS_NOT_NULL => generateIsNotNull(nullCheck, operands.head)
+
+ // logic: AND / OR fold all of their operands from left to right
+ case AND =>
+ operands.reduceLeft { (acc: GeneratedExpression, next: GeneratedExpression) =>
+ requireBoolean(acc)
+ requireBoolean(next)
+ generateAnd(nullCheck, acc, next)
+ }
+
+ case OR =>
+ operands.reduceLeft { (acc: GeneratedExpression, next: GeneratedExpression) =>
+ requireBoolean(acc)
+ requireBoolean(next)
+ generateOr(nullCheck, acc, next)
+ }
+
+ case NOT => genBooleanTest(operand => generateNot(nullCheck, operand))
+ case CASE => generateIfElse(nullCheck, operands, resultType)
+ case IS_TRUE => genBooleanTest(operand => generateIsTrue(operand))
+ case IS_NOT_TRUE => genBooleanTest(operand => generateIsNotTrue(operand))
+ case IS_FALSE => genBooleanTest(operand => generateIsFalse(operand))
+ case IS_NOT_FALSE => genBooleanTest(operand => generateIsNotFalse(operand))
+
+ // casting
+ case CAST | REINTERPRET => generateCast(nullCheck, operands.head, resultType)
+
+ // as / renaming: the operand itself is the result
+ case AS => operands.head
+
+ // string arithmetic (only the left operand is required to be a string)
+ case CONCAT =>
+ val left = operands.head
+ val right = operands(1)
+ requireString(left)
+ generateArithmeticOperator("+", nullCheck, resultType, left, right)
+
+ // arrays
+ case ARRAY_VALUE_CONSTRUCTOR => generateArray(this, resultType, operands)
+
+ case ITEM =>
+ val array = operands.head
+ val index = operands(1)
+ requireArray(array)
+ requireInteger(index)
+ generateArrayElementAt(this, array, index)
+
+ case CARDINALITY =>
+ val array = operands.head
+ requireArray(array)
+ generateArrayCardinality(nullCheck, array)
+
+ case ELEMENT =>
+ val array = operands.head
+ requireArray(array)
+ generateArrayElement(this, array)
+
+ // advanced scalar functions are looked up in the FunctionGenerator
+ case sqlOperator: SqlOperator =>
+ val generator = FunctionGenerator
+ .getCallGenerator(sqlOperator, operands.map(_.resultType), resultType)
+ .getOrElse(throw new CodeGenException(s"Unsupported call: $sqlOperator \n" +
+ s"If you think this function should be supported, " +
+ s"you can create an issue and start a discussion for it."))
+ generator.generate(this, operands)
+
+ // unknown or invalid
+ case other =>
+ throw new CodeGenException(s"Unsupported call: $other")
+ }
+ }
+
/**
 * Window aggregates ([[RexOver]]) are not supported by this code generator.
 *
 * @throws CodeGenException always
 */
override def visitOver(over: RexOver): GeneratedExpression = {
  val message = "Aggregate functions over windows are not supported yet."
  throw new CodeGenException(message)
}
+
/**
 * Sub-queries ([[RexSubQuery]]) are not supported by this code generator.
 *
 * @throws CodeGenException always
 */
override def visitSubQuery(subQuery: RexSubQuery): GeneratedExpression = {
  val message = "Subqueries are not supported yet."
  throw new CodeGenException(message)
}
+
+ // ----------------------------------------------------------------------------------------------
+ // generator helping methods
+ // ----------------------------------------------------------------------------------------------
+
/**
 * Generates access to field `index` of the input named `inputTerm`.
 *
 * The access (and unboxing) expression is generated at most once per `(inputTerm, index)` pair
 * and cached in `reusableInputUnboxingExprs`. The returned expression only carries the result
 * term, null term and type, and its code is empty. Presumably the cached code is emitted once
 * elsewhere — verify against the callers that read `reusableInputUnboxingExprs`.
 *
 * @param inputType type of the input
 * @param inputTerm term of the input variable in the generated code
 * @param index index of the field to access
 * @param pojoFieldMapping optional mapping of `index` to the POJO field index
 * @return expression referencing the (cached) field access, with empty code
 */
private def generateInputAccess(
    inputType: TypeInformation[Any],
    inputTerm: String,
    index: Int,
    pojoFieldMapping: Option[Array[Int]])
  : GeneratedExpression = {
  val cacheKey = (inputTerm, index)

  val cachedAccess = reusableInputUnboxingExprs.get(cacheKey).getOrElse {
    // first access of this field: generate it and remember it for later reuse
    val generated =
      if (nullableInput) {
        generateNullableInputFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
      } else {
        generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
      }
    reusableInputUnboxingExprs(cacheKey) = generated
    generated
  }

  // hide the generated code as it will be executed only once
  GeneratedExpression(
    cachedAccess.resultTerm,
    cachedAccess.nullTerm,
    "",
    cachedAccess.resultType)
}
+
/**
 * Generates access to field `index` of an input that may itself be null.
 *
 * If `inputTerm` is null at runtime, the result is the primitive default value of the field
 * type and the null flag is set. Otherwise the regular access generated by
 * [[generateFieldAccess]] is executed, and its result and null flag are copied.
 *
 * @param inputType type of the input (composite or atomic)
 * @param inputTerm term of the input variable in the generated code
 * @param index index of the field to access
 * @param pojoFieldMapping optional mapping of `index` to the POJO field index; applied only to
 *                         POJO inputs and only when present
 * @return expression holding the field value and its null flag
 * @throws CodeGenException if the input type is neither composite nor atomic
 */
private def generateNullableInputFieldAccess(
    inputType: TypeInformation[Any],
    inputTerm: String,
    index: Int,
    pojoFieldMapping: Option[Array[Int]])
  : GeneratedExpression = {
  val resultTerm = newName("result")
  val nullTerm = newName("isNull")

  val fieldType = inputType match {
    case ct: CompositeType[_] =>
      // Resolve the field index exactly like generateFieldAccess does: the POJO mapping is only
      // applied when one is present, so a POJO without a mapping falls back to `index` instead
      // of failing on Option.get (and the type always matches the field that is accessed).
      val fieldIndex = (ct, pojoFieldMapping) match {
        case (_: PojoTypeInfo[_], Some(mapping)) => mapping(index)
        case _ => index
      }
      ct.getTypeAt(fieldIndex)
    case at: AtomicType[_] => at
    case _ => throw new CodeGenException("Unsupported type for input field access.")
  }
  val resultTypeTerm = primitiveTypeTermForTypeInfo(fieldType)
  val defaultValue = primitiveDefaultValue(fieldType)
  val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)

  // the field access code is only executed when the input itself is not null
  val inputCheckCode =
    s"""
      |$resultTypeTerm $resultTerm;
      |boolean $nullTerm;
      |if ($inputTerm == null) {
      | $resultTerm = $defaultValue;
      | $nullTerm = true;
      |}
      |else {
      | ${fieldAccessExpr.code}
      | $resultTerm = ${fieldAccessExpr.resultTerm};
      | $nullTerm = ${fieldAccessExpr.nullTerm};
      |}
      |""".stripMargin

  GeneratedExpression(resultTerm, nullTerm, inputCheckCode, fieldType)
}
+
/**
 * Generates code that reads field `index` of the input `inputTerm`. Callers either know the
 * input is non-null or guard the access themselves.
 *
 * For composite inputs, the field is read with the accessor returned by `fieldAccessorFor`:
 * a public field, a generic field, a method, a product `getField`, or reflection for private
 * fields. Primitive fields become non-null literals. All other values are cast to the boxed
 * field type and unboxed. Atomic inputs are themselves the value.
 *
 * @param inputType type of the input
 * @param inputTerm term of the input variable in the generated code
 * @param index index of the field to access
 * @param pojoFieldMapping optional mapping of `index` to the POJO field index; applied only to
 *                         POJO inputs and only when present
 * @return expression holding the field value
 * @throws CodeGenException if the input type is neither composite nor atomic
 */
private def generateFieldAccess(
    inputType: TypeInformation[_],
    inputTerm: String,
    index: Int,
    pojoFieldMapping: Option[Array[Int]])
  : GeneratedExpression = inputType match {

  case ct: CompositeType[_] =>
    val fieldIndex = (ct, pojoFieldMapping) match {
      case (_: PojoTypeInfo[_], Some(mapping)) => mapping(index)
      case _ => index
    }
    val accessor = fieldAccessorFor(ct, fieldIndex)
    val fieldType: TypeInformation[Any] = ct.getTypeAt(fieldIndex)
    val fieldTypeTerm = boxedTypeTermForTypeInfo(fieldType)

    // casts an object-valued access to the boxed field type and unboxes it
    def castAndUnbox(access: String): GeneratedExpression =
      generateInputFieldUnboxing(fieldType, s"($fieldTypeTerm) $access")

    accessor match {
      case ObjectFieldAccessor(field) if isFieldPrimitive(field) =>
        generateNonNullLiteral(fieldType, s"$inputTerm.${field.getName}")

      case ObjectFieldAccessor(field) =>
        castAndUnbox(s"$inputTerm.${field.getName}")

      case ObjectGenericFieldAccessor(fieldName) =>
        castAndUnbox(s"$inputTerm.$fieldName")

      case ObjectMethodAccessor(methodName) =>
        castAndUnbox(s"$inputTerm.$methodName()")

      case ProductAccessor(i) =>
        castAndUnbox(s"$inputTerm.getField($i)")

      case ObjectPrivateFieldAccessor(field) =>
        // the reflective Field member must be registered before building the read expression
        val fieldTerm = addReusablePrivateFieldAccess(ct.getTypeClass, field.getName)
        val reflectiveAccessCode = reflectiveFieldReadAccess(fieldTerm, field, inputTerm)
        if (isFieldPrimitive(field)) {
          generateNonNullLiteral(fieldType, reflectiveAccessCode)
        } else {
          generateInputFieldUnboxing(fieldType, reflectiveAccessCode)
        }
    }

  case at: AtomicType[_] =>
    val atomicTypeTerm = boxedTypeTermForTypeInfo(at)
    generateInputFieldUnboxing(at, s"($atomicTypeTerm) $inputTerm")

  case _ =>
    throw new CodeGenException("Unsupported type for input field access.")
}
+
/**
 * Generates a typed NULL value. The generated code declares a result variable holding the
 * primitive default value of `resultType` and a null flag set to `true`.
 *
 * @param resultType type of the null value
 * @return expression representing NULL
 * @throws CodeGenException if null checks are disabled, because NULL cannot be represented then
 */
private def generateNullLiteral(resultType: TypeInformation[_]): GeneratedExpression = {
  val resultTerm = newName("result")
  val nullTerm = newName("isNull")
  val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
  val defaultValue = primitiveDefaultValue(resultType)

  if (!nullCheck) {
    throw new CodeGenException("Null literals are not allowed if nullCheck is disabled.")
  }

  val nullCode =
    s"""
      |$resultTypeTerm $resultTerm = $defaultValue;
      |boolean $nullTerm = true;
      |""".stripMargin
  GeneratedExpression(resultTerm, nullTerm, nullCode, resultType)
}
+
/**
 * Generates an expression whose value can never be null.
 *
 * The value of the Java expression `literalCode` is assigned to a fresh result variable of the
 * primitive type of `literalType`. When null checks are enabled, a null flag that is constantly
 * `false` is declared too. Without null checks, the returned null term is not declared in the
 * generated code.
 *
 * @param literalType type of the value produced by `literalCode`
 * @param literalCode Java expression producing the value
 * @return expression holding the value
 */
private[flink] def generateNonNullLiteral(
    literalType: TypeInformation[_],
    literalCode: String)
  : GeneratedExpression = {
  val resultTerm = newName("result")
  val nullTerm = newName("isNull")
  val resultTypeTerm = primitiveTypeTermForTypeInfo(literalType)

  val literalDeclaration =
    if (!nullCheck) {
      s"""
        |$resultTypeTerm $resultTerm = $literalCode;
        |""".stripMargin
    } else {
      s"""
        |$resultTypeTerm $resultTerm = $literalCode;
        |boolean $nullTerm = false;
        |""".stripMargin
    }

  GeneratedExpression(resultTerm, nullTerm, literalDeclaration, literalType)
}
+
/**
 * Generates an expression for a Java enum constant. The constant is referenced by its
 * qualified name, is never null (null term `"false"`), needs no setup code, and is typed as a
 * generic type of the enum's declaring class.
 *
 * @param enum enum constant to reference
 * @return expression referencing the constant
 */
private[flink] def generateSymbol(enum: Enum[_]): GeneratedExpression = {
  val symbolTerm = qualifyEnum(enum)
  GeneratedExpression(symbolTerm, "false", "", new GenericTypeInfo(enum.getDeclaringClass))
}
+
/**
 * Converts the external boxed format to an internal, mostly primitive field representation.
 * Wrapper types can be auto-unboxed to their corresponding primitive type (Integer -> int).
 * Time points are converted to their internal representation (Timestamp -> internal timestamp
 * in long).
 *
 * The generated code depends on null checking and on whether the field type is a reference:
 *  - null checks on, non-reference type: the boxed value is read into a temporary, null-checked,
 *    and unboxed into the result (primitive default value when null);
 *  - null checks on, reference type: the (converted) value is assigned directly and the null
 *    flag compares `fieldTerm` with null;
 *  - null checks off: the (converted) value is assigned directly and no null flag is declared.
 *
 * NOTE(review): in the reference branch `fieldTerm` is evaluated twice (once for the value and
 * once for the null check), and for time points the conversion runs before the null check —
 * confirm that the conversion tolerates null inputs.
 *
 * @param fieldType type of field
 * @param fieldTerm expression term of field to be unboxed
 * @return internal unboxed field representation
 */
private[flink] def generateInputFieldUnboxing(
    fieldType: TypeInformation[_],
    fieldTerm: String)
  : GeneratedExpression = {
  val tmpTerm = newName("tmp")
  val resultTerm = newName("result")
  val nullTerm = newName("isNull")
  val tmpTypeTerm = boxedTypeTermForTypeInfo(fieldType)
  val resultTypeTerm = primitiveTypeTermForTypeInfo(fieldType)
  val defaultValue = primitiveDefaultValue(fieldType)

  // time points need an explicit conversion into their internal representation
  val unboxedFieldCode =
    if (isTimePoint(fieldType)) timePointToInternalCode(fieldType, fieldTerm) else fieldTerm

  val unboxingCode =
    if (!nullCheck) {
      s"""
        |$resultTypeTerm $resultTerm = $unboxedFieldCode;
        |""".stripMargin
    } else if (isReference(fieldType)) {
      s"""
        |$resultTypeTerm $resultTerm = $unboxedFieldCode;
        |boolean $nullTerm = $fieldTerm == null;
        |""".stripMargin
    } else {
      s"""
        |$tmpTypeTerm $tmpTerm = $unboxedFieldCode;
        |boolean $nullTerm = $tmpTerm == null;
        |$resultTypeTerm $resultTerm;
        |if ($nullTerm) {
        | $resultTerm = $defaultValue;
        |}
        |else {
        | $resultTerm = $tmpTerm;
        |}
        |""".stripMargin
    }

  GeneratedExpression(resultTerm, nullTerm, unboxingCode, fieldType)
}
+
/**
 * Converts the internal, mostly primitive field representation to an external boxed format.
 * Primitive types can be autoboxed to their corresponding object type (int -> Integer).
 * Internal date/time/timestamp representations are converted to their external java.sql.*
 * objects (internal timestamp in long -> Timestamp). With null checks enabled, a null input
 * becomes a null object.
 *
 * @param expr expression to be boxed
 * @return external boxed field representation; `expr` itself for every non-time-point type
 */
private[flink] def generateOutputFieldBoxing(expr: GeneratedExpression): GeneratedExpression =
  expr.resultType match {
    // internal date/time/timestamp values become java.sql.* objects
    case SqlTimeTypeInfo.DATE | SqlTimeTypeInfo.TIME | SqlTimeTypeInfo.TIMESTAMP =>
      val boxedTerm = newName("result")
      val boxedTypeTerm = boxedTypeTermForTypeInfo(expr.resultType)
      val conversion = internalToTimePointCode(expr.resultType, expr.resultTerm)

      val boxingCode =
        if (!nullCheck) {
          s"""
            |${expr.code}
            |$boxedTypeTerm $boxedTerm = $conversion;
            |""".stripMargin
        } else {
          s"""
            |${expr.code}
            |$boxedTypeTerm $boxedTerm;
            |if (${expr.nullTerm}) {
            | $boxedTerm = null;
            |}
            |else {
            | $boxedTerm = $conversion;
            |}
            |""".stripMargin
        }

      GeneratedExpression(boxedTerm, expr.nullTerm, boxingCode, expr.resultType)

    // every other type is autoboxed by the Java compiler or needs no boxing at all
    case _ =>
      expr
  }
+
+ // ----------------------------------------------------------------------------------------------
+ // Reusable code snippets
+ // ----------------------------------------------------------------------------------------------
+
/**
 * Adds a reusable output record to the member area of the generated [[Function]].
 * The member is named `outRecordTerm` and is instantiated from the type class of `ti`.
 * A [[RowTypeInfo]] record is created with the row arity. Any other type class is created
 * through its no-argument constructor.
 *
 * NOTE(review): only the Row variant is declared `transient` — confirm whether that difference
 * is intentional.
 *
 * @param ti type information of type class to be instantiated during runtime
 */
def addReusableOutRecord(ti: TypeInformation[_]): Unit = {
  val typeClassName = ti.getTypeClass.getCanonicalName

  val memberStatement = ti match {
    case rowType: RowTypeInfo =>
      s"""
        |transient $typeClassName $outRecordTerm =
        | new $typeClassName(${rowType.getArity});
        |""".stripMargin
    case _ =>
      s"""
        |$typeClassName $outRecordTerm =
        | new $typeClassName();
        |""".stripMargin
  }

  reusableMemberStatements.add(memberStatement)
}
+
/**
 * Adds a reusable [[java.lang.reflect.Field]] to the member area of the generated [[Function]].
 * The field is looked up once with `TypeExtractor.getDeclaredField` and made accessible during
 * initialization, so it can also be used for POJO fields that are not public.
 *
 * @param clazz class of containing field
 * @param fieldName name of field to be extracted and instantiated during runtime
 * @return member variable term
 */
def addReusablePrivateFieldAccess(clazz: Class[_], fieldName: String): String = {
  val className = clazz.getCanonicalName
  val fieldTerm = s"field_${className.replace('.', '$')}_$fieldName"

  // member declaration: resolve the (possibly private) field once
  reusableMemberStatements.add(
    s"""
      |transient java.lang.reflect.Field $fieldTerm =
      | org.apache.flink.api.java.typeutils.TypeExtractor.getDeclaredField(
      | $className.class, "$fieldName");
      |""".stripMargin)

  // initialization: lift Java access checks so non-public fields can be read
  reusableInitStatements.add(
    s"""
      |$fieldTerm.setAccessible(true);
      |""".stripMargin)

  fieldTerm
}
+
/**
 * Adds a reusable [[java.math.BigDecimal]] to the member area of the generated [[Function]].
 *
 * The constants ZERO, ONE and TEN are referenced directly and no member is added for them.
 * They are matched with `equals`, which is scale-sensitive, so e.g. `0.00` still gets its own
 * member.
 *
 * @param decimal decimal object to be instantiated during runtime
 * @return member variable term (or the qualified name of a predefined constant)
 */
def addReusableDecimal(decimal: JBigDecimal): String = {
  val predefinedConstant =
    if (decimal == JBigDecimal.ZERO) Some("java.math.BigDecimal.ZERO")
    else if (decimal == JBigDecimal.ONE) Some("java.math.BigDecimal.ONE")
    else if (decimal == JBigDecimal.TEN) Some("java.math.BigDecimal.TEN")
    else None

  predefinedConstant.getOrElse {
    val fieldTerm = newName("decimal")
    reusableMemberStatements.add(
      s"""
        |transient java.math.BigDecimal $fieldTerm =
        | new java.math.BigDecimal("${decimal.toString}");
        |""".stripMargin)
    fieldTerm
  }
}
+
/**
 * Adds a reusable [[UserDefinedFunction]] to the member area of the generated [[Function]].
 * The member is declared as `null` and instantiated during initialization through the class's
 * no-argument constructor, which is made accessible first. The constructor therefore must
 * exist but does not have to be public.
 *
 * @param function [[UserDefinedFunction]] object to be instantiated during runtime
 * @return member variable term
 */
def addReusableFunction(function: UserDefinedFunction): String = {
  val classQualifier = function.getClass.getCanonicalName
  val termSuffix = classQualifier.replace('.', '$')
  val fieldTerm = s"function_$termSuffix"
  val constructorTerm = s"constructor_$termSuffix"

  // member: declared empty, assigned during initialization
  reusableMemberStatements.add(
    s"""
      |transient $classQualifier $fieldTerm = null;
      |""".stripMargin)

  // initialization: instantiate via the (possibly non-public) no-arg constructor
  reusableInitStatements.add(
    s"""
      |java.lang.reflect.Constructor $constructorTerm =
      | $classQualifier.class.getDeclaredConstructor();
      |$constructorTerm.setAccessible(true);
      |$fieldTerm = ($classQualifier) $constructorTerm.newInstance();
      |""".stripMargin)

  fieldTerm
}
+
/**
 * Adds a reusable array of `size` elements to the member area of the generated [[Function]].
 *
 * @param clazz array class (e.g. `int[]` or `String[][]`)
 * @param size length of the first array dimension
 * @return member variable term
 */
def addReusableArray(clazz: Class[_], size: Int): String = {
  val fieldTerm = newName("array")
  // canonical names of array classes end in "[]" (works for int[] etc.); inserting the size into
  // the first bracket pair turns the type name into the allocation expression
  val arrayTypeName = clazz.getCanonicalName
  val allocation = arrayTypeName.replaceFirst("\\[", s"[$size")

  reusableMemberStatements.add(
    s"""
      |transient $arrayTypeName $fieldTerm =
      | new $allocation;
      |""".stripMargin)
  fieldTerm
}
+
/**
 * Adds a reusable timestamp to the beginning of the SAM of the generated [[Function]].
 * The per-record variable `timestamp` holds `System.currentTimeMillis()`.
 *
 * @return the term of the generated timestamp variable
 */
def addReusableTimestamp(): String = {
  val fieldTerm = "timestamp"

  reusablePerRecordStatements.add(
    s"""
      |final long $fieldTerm = java.lang.System.currentTimeMillis();
      |""".stripMargin)

  fieldTerm
}
+
/**
 * Adds a reusable local timestamp to the beginning of the SAM of the generated [[Function]].
 * The per-record variable `localtimestamp` is the reusable timestamp shifted by the offset of
 * the JVM's default time zone at that instant.
 *
 * @return the term of the generated local timestamp variable
 */
def addReusableLocalTimestamp(): String = {
  val fieldTerm = s"localtimestamp"

  val timestamp = addReusableTimestamp()

  // refer to the timestamp only through the term returned above, never by a hard-coded name
  val field =
    s"""
      |final long $fieldTerm = $timestamp + java.util.TimeZone.getDefault().getOffset($timestamp);
      |""".stripMargin
  reusablePerRecordStatements.add(field)
  fieldTerm
}
+
/**
 * Adds a reusable time to the beginning of the SAM of the generated [[Function]].
 * The per-record variable `time` holds the millisecond of the day of the reusable timestamp,
 * always within [0, MILLIS_PER_DAY).
 *
 * Adopted from org.apache.calcite.runtime.SqlFunctions.currentTime(). The remainder is shifted
 * into the non-negative range within a single expression, so the generated variable can stay
 * `final` and is never reassigned.
 *
 * @return the term of the generated time variable
 */
def addReusableTime(): String = {
  val fieldTerm = s"time"

  val timestamp = addReusableTimestamp()

  // ((t % D) + D) % D equals t % D for non-negative remainders and t % D + D for negative ones
  val field =
    s"""
      |final int $fieldTerm = (int) ((($timestamp % ${DateTimeUtils.MILLIS_PER_DAY})
      | + ${DateTimeUtils.MILLIS_PER_DAY}) % ${DateTimeUtils.MILLIS_PER_DAY});
      |""".stripMargin
  reusablePerRecordStatements.add(field)
  fieldTerm
}
+
/**
 * Adds a reusable local time to the beginning of the SAM of the generated [[Function]].
 * The per-record variable `localtime` is the reusable local timestamp modulo MILLIS_PER_DAY,
 * adopted from org.apache.calcite.runtime.SqlFunctions.localTime().
 *
 * @return the term of the generated local time variable
 */
def addReusableLocalTime(): String = {
  val fieldTerm = "localtime"
  val localTimestampTerm = addReusableLocalTimestamp()

  reusablePerRecordStatements.add(
    s"""
      |final int $fieldTerm = (int) ($localTimestampTerm % ${DateTimeUtils.MILLIS_PER_DAY});
      |""".stripMargin)

  fieldTerm
}
+
+
/**
 * Adds a reusable date to the beginning of the SAM of the generated [[Function]].
 * The per-record variable `date` holds the number of days since the epoch of the reusable
 * timestamp, i.e. floor(timestamp / MILLIS_PER_DAY).
 *
 * Adopted from org.apache.calcite.runtime.SqlFunctions.currentDate(): truncating division is
 * corrected by one day when the *raw* remainder is negative. The normalized time of day cannot
 * be used for this check because it is never negative. The whole computation is one expression,
 * so the generated variable stays `final` and is never reassigned.
 *
 * @return the term of the generated date variable
 */
def addReusableDate(): String = {
  val fieldTerm = s"date"

  val timestamp = addReusableTimestamp()

  val field =
    s"""
      |final int $fieldTerm = (int) ($timestamp / ${DateTimeUtils.MILLIS_PER_DAY}
      | - ($timestamp % ${DateTimeUtils.MILLIS_PER_DAY} < 0 ? 1 : 0));
      |""".stripMargin
  reusablePerRecordStatements.add(field)
  fieldTerm
}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Compiler.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Compiler.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Compiler.scala
new file mode 100644
index 0000000..4c12003
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Compiler.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.codegen
+
+import org.apache.flink.api.common.InvalidProgramException
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.SimpleCompiler
+
/**
 * Compiles generated Java source code with Janino's [[SimpleCompiler]] and loads the resulting
 * class.
 *
 * @tparam T type the loaded class is cast to
 */
trait Compiler[T] {

  /**
   * Compiles `code` on top of the parent class loader `cl` and loads the class `name` from it.
   *
   * Compile errors are rethrown as [[InvalidProgramException]] wrapping the original
   * [[CompileException]]. The `@throws` annotation below still declares [[CompileException]].
   *
   * @param cl parent class loader of the compiler; must not be null
   * @param name name of the class defined in `code`
   * @param code Java source code to compile
   * @return the loaded class, cast to `Class[T]`
   */
  @throws(classOf[CompileException])
  def compile(cl: ClassLoader, name: String, code: String): Class[T] = {
    require(cl != null, "Classloader must not be null.")

    val janino = new SimpleCompiler()
    janino.setParentClassLoader(cl)

    try {
      janino.cook(code)
    } catch {
      case compileError: CompileException =>
        val message = "Table program cannot be compiled. " + "This is a bug. Please file an issue."
        throw new InvalidProgramException(message, compileError)
    }

    val loadedClass = janino.getClassLoader.loadClass(name)
    loadedClass.asInstanceOf[Class[T]]
  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
new file mode 100644
index 0000000..94007de
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.codegen
+
+import java.util
+
+import org.apache.calcite.plan.RelOptPlanner
+import org.apache.calcite.rex.{RexBuilder, RexNode}
+import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.typeutils.TypeConverter
+import org.apache.flink.table.api.TableConfig
+import org.apache.flink.types.Row
+
+import scala.collection.JavaConverters._
+
/**
 * Evaluates constant expressions using Flink's [[CodeGenerator]].
 *
 * All constant expressions that can be represented as literals are compiled into one
 * [[MapFunction]], which is applied once to an empty [[Row]]. The fields of its result become
 * the reduced literals. Expressions of object types (ANY, ROW, ARRAY) are not reduced and are
 * returned unchanged.
 *
 * @param config table configuration used for code generation
 */
class ExpressionReducer(config: TableConfig)
  extends RelOptPlanner.Executor with Compiler[MapFunction[Row, Row]] {

  private val EMPTY_ROW_INFO = TypeConverter.DEFAULT_ROW_TYPE
  private val EMPTY_ROW = new Row(0)

  // Types for which no literal can be created yet. Both passes of `reduce` consult this single
  // set, which keeps the reduced row fields aligned with the expressions that were evaluated.
  private val UNREDUCIBLE_TYPES: Set[SqlTypeName] =
    Set(SqlTypeName.ANY, SqlTypeName.ROW, SqlTypeName.ARRAY)

  override def reduce(
      rexBuilder: RexBuilder,
      constExprs: util.List[RexNode],
      reducedValues: util.List[RexNode]): Unit = {

    val typeFactory = rexBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]

    val literals = constExprs.asScala.map(e => (e.getType.getSqlTypeName, e)).flatMap {

      // we need to cast here for RexBuilder.makeLiteral
      case (SqlTypeName.DATE, e) =>
        Some(
          rexBuilder.makeCast(typeFactory.createTypeFromTypeInfo(BasicTypeInfo.INT_TYPE_INFO), e)
        )
      case (SqlTypeName.TIME, e) =>
        Some(
          rexBuilder.makeCast(typeFactory.createTypeFromTypeInfo(BasicTypeInfo.INT_TYPE_INFO), e)
        )
      case (SqlTypeName.TIMESTAMP, e) =>
        Some(
          rexBuilder.makeCast(typeFactory.createTypeFromTypeInfo(BasicTypeInfo.LONG_TYPE_INFO), e)
        )

      // we don't support object literals yet, we skip those constant expressions
      case (typeName, _) if UNREDUCIBLE_TYPES.contains(typeName) => None

      case (_, e) => Some(e)
    }

    val literalTypes = literals.map(e => FlinkTypeFactory.toTypeInfo(e.getType))
    val resultType = new RowTypeInfo(literalTypes: _*)

    // generate MapFunction
    val generator = new CodeGenerator(config, false, EMPTY_ROW_INFO)

    val result = generator.generateResultExpression(
      resultType,
      resultType.getFieldNames,
      literals)

    val generatedFunction = generator.generateFunction[MapFunction[Row, Row]](
      "ExpressionReducer",
      classOf[MapFunction[Row, Row]],
      s"""
        |${result.code}
        |return ${result.resultTerm};
        |""".stripMargin,
      resultType.asInstanceOf[TypeInformation[Any]])

    val clazz = compile(getClass.getClassLoader, generatedFunction.name, generatedFunction.code)
    val function = clazz.newInstance()

    // execute
    val reduced = function.map(EMPTY_ROW)

    // add the reduced results or keep them unreduced; reducedIdx only advances for expressions
    // that were part of the evaluated row
    var reducedIdx = 0
    constExprs.asScala.foreach { unreduced =>
      if (UNREDUCIBLE_TYPES.contains(unreduced.getType.getSqlTypeName)) {
        // we insert the original expression for object literals
        reducedValues.add(unreduced)
      } else {
        val literal = rexBuilder.makeLiteral(
          reduced.getField(reducedIdx),
          unreduced.getType,
          true)
        reducedValues.add(literal)
        reducedIdx += 1
      }
    }
  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Indenter.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Indenter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Indenter.scala
new file mode 100644
index 0000000..187e730
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Indenter.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.codegen
+
/**
 * String interpolator that keeps multi-line arguments aligned.
 *
 * An argument may be inserted right after a line break that is followed only by whitespace.
 * In that case, every further line of the argument is indented by that same whitespace.
 * It is enabled through `Indenter` and used as `j"..."`.
 */
class IndentStringContext(sc: StringContext) {

  /**
   * Interpolates `args` into the parts of `sc`. Line breaks inside each argument are
   * re-indented with the whitespace that precedes the argument on its line.
   *
   * @param args values to insert; `toString` is called on each of them
   * @return the interpolated string
   */
  def j(args: Any*): String = {
    val builder = new StringBuilder()
    sc.parts.zip(args).foreach { case (part, arg) =>
      builder.append(part)
      // an empty indent leaves the argument text unchanged
      builder.append(arg.toString.replace("\n", "\n" + getindent(part)))
    }
    if (sc.parts.size > args.size) {
      builder.append(sc.parts.last)
    }
    builder.toString()
  }

  /**
   * Returns the whitespace after the last line break of `str`. Returns "" when `str` contains
   * no line break or when the text after it is not whitespace only.
   */
  def getindent(str: String): String = str.lastIndexOf("\n") match {
    case -1 => ""
    case lastBreak =>
      val tail = str.substring(lastBreak + 1)
      if (tail.trim.isEmpty) tail else ""
  }
}
+
/**
 * Brings the `j` interpolator of [[IndentStringContext]] into scope via `import Indenter._`.
 */
object Indenter {
  // Defining an implicit conversion is a language feature that must be enabled explicitly;
  // without this import scalac reports a feature warning at the definition below.
  import scala.language.implicitConversions

  /** Wraps a string context so that `j"..."` interpolation is available. */
  implicit def toISC(sc: StringContext): IndentStringContext = new IndentStringContext(sc)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
new file mode 100644
index 0000000..649d3b2
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.codegen.calls
+
+import java.math.{BigDecimal => JBigDecimal}
+
+import org.apache.calcite.linq4j.tree.Types
+import org.apache.calcite.runtime.SqlFunctions
+import org.apache.flink.table.functions.utils.MathFunctions
+
/**
 * Reflective handles to static methods, each resolved once with Calcite's `Types.lookupMethod`.
 * They are presumably referenced by the call generators in this package.
 */
object BuiltInMethods {

  // java.lang.Math
  val LOG10: java.lang.reflect.Method =
    Types.lookupMethod(classOf[Math], "log10", classOf[Double])
  val EXP: java.lang.reflect.Method =
    Types.lookupMethod(classOf[Math], "exp", classOf[Double])
  val POWER: java.lang.reflect.Method =
    Types.lookupMethod(classOf[Math], "pow", classOf[Double], classOf[Double])

  // Flink's MathFunctions: the (double, BigDecimal) overload of power
  val POWER_DEC: java.lang.reflect.Method =
    Types.lookupMethod(classOf[MathFunctions], "power", classOf[Double], classOf[JBigDecimal])

  val LN: java.lang.reflect.Method =
    Types.lookupMethod(classOf[Math], "log", classOf[Double])

  // Calcite's SqlFunctions
  val ABS: java.lang.reflect.Method =
    Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[Double])
  val ABS_DEC: java.lang.reflect.Method =
    Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[JBigDecimal])
  val LIKE_WITH_ESCAPE: java.lang.reflect.Method =
    Types.lookupMethod(
      classOf[SqlFunctions], "like", classOf[String], classOf[String], classOf[String])
  val SIMILAR_WITH_ESCAPE: java.lang.reflect.Method =
    Types.lookupMethod(
      classOf[SqlFunctions], "similar", classOf[String], classOf[String], classOf[String])
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala
new file mode 100644
index 0000000..1bc9fbb
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.codegen.calls
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.codegen.CodeGenUtils._
+import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression}
+
/**
 * Generates code for a single function call from its already generated operands.
 */
trait CallGenerator {

  /**
   * Generates the expression for the call.
   *
   * @param codeGenerator code generator the call is generated with
   * @param operands generated operand expressions of the call
   * @return the generated call expression
   */
  def generate(codeGenerator: CodeGenerator, operands: Seq[GeneratedExpression])
    : GeneratedExpression
}
+
object CallGenerator {

  /**
   * Generates a call whose result is NULL as soon as any operand is NULL.
   *
   * The operand code is emitted first. With null checks enabled and at least one operand, the
   * null flag is the OR of all operand null flags, and `call` is only executed when it is
   * false. Otherwise the result is the primitive default value. Without operands, the null flag
   * is `false`. Without null checks, no null flag is declared.
   *
   * @param nullCheck whether null checks are generated
   * @param returnType type of the call result
   * @param operands generated operand expressions
   * @param call builds the Java call expression from the operand result terms
   * @return expression holding the call result
   */
  def generateCallIfArgsNotNull(
      nullCheck: Boolean,
      returnType: TypeInformation[_],
      operands: Seq[GeneratedExpression])
      (call: (Seq[String]) => String)
    : GeneratedExpression = {
    val resultTerm = newName("result")
    val nullTerm = newName("isNull")
    val resultTypeTerm = primitiveTypeTermForTypeInfo(returnType)
    val defaultValue = primitiveDefaultValue(returnType)

    val operandCode = operands.map(_.code).mkString("\n")
    val operandTerms = operands.map(_.resultTerm)

    val callCode =
      if (!nullCheck) {
        s"""
          |$operandCode
          |$resultTypeTerm $resultTerm = ${call(operandTerms)};
          |""".stripMargin
      } else if (operands.isEmpty) {
        s"""
          |$operandCode
          |boolean $nullTerm = false;
          |$resultTypeTerm $resultTerm = ${call(operandTerms)};
          |""".stripMargin
      } else {
        s"""
          |$operandCode
          |boolean $nullTerm = ${operands.map(_.nullTerm).mkString(" || ")};
          |$resultTypeTerm $resultTerm;
          |if ($nullTerm) {
          | $resultTerm = $defaultValue;
          |}
          |else {
          | $resultTerm = ${call(operandTerms)};
          |}
          |""".stripMargin
      }

    GeneratedExpression(resultTerm, nullTerm, callCode, returnType)
  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala
new file mode 100644
index 0000000..d644847
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.codegen.calls
+
+import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression}
+
+/**
+  * Generates a function call that determines the current point in time as a date, time or
+  * timestamp, either in the local time zone or not.
+  *
+  * @param targetType type of the generated value; one of [[SqlTimeTypeInfo.DATE]],
+  *                   [[SqlTimeTypeInfo.TIME]] or [[SqlTimeTypeInfo.TIMESTAMP]]
+  * @param local whether the local variant of the time point is generated; has no effect for
+  *              [[SqlTimeTypeInfo.DATE]]
+  */
+class CurrentTimePointCallGen(
+    targetType: TypeInformation[_],
+    local: Boolean)
+  extends CallGenerator {
+
+  /**
+    * Generates a non-null literal of `targetType` from the term returned by the code
+    * generator's corresponding `addReusable*` method. The operands are not used.
+    *
+    * @throws IllegalArgumentException if `targetType` is not DATE, TIME or TIMESTAMP
+    */
+  override def generate(
+      codeGenerator: CodeGenerator,
+      operands: Seq[GeneratedExpression])
+    : GeneratedExpression = targetType match {
+    case SqlTimeTypeInfo.TIME if local =>
+      val time = codeGenerator.addReusableLocalTime()
+      codeGenerator.generateNonNullLiteral(targetType, time)
+
+    case SqlTimeTypeInfo.TIMESTAMP if local =>
+      val timestamp = codeGenerator.addReusableLocalTimestamp()
+      codeGenerator.generateNonNullLiteral(targetType, timestamp)
+
+    case SqlTimeTypeInfo.DATE =>
+      val date = codeGenerator.addReusableDate()
+      codeGenerator.generateNonNullLiteral(targetType, date)
+
+    case SqlTimeTypeInfo.TIME =>
+      val time = codeGenerator.addReusableTime()
+      codeGenerator.generateNonNullLiteral(targetType, time)
+
+    case SqlTimeTypeInfo.TIMESTAMP =>
+      val timestamp = codeGenerator.addReusableTimestamp()
+      codeGenerator.generateNonNullLiteral(targetType, timestamp)
+
+    case _ =>
+      // fail with a descriptive message instead of an opaque MatchError
+      throw new IllegalArgumentException(
+        s"Unsupported target type for current time point: $targetType")
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FloorCeilCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FloorCeilCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FloorCeilCallGen.scala
new file mode 100644
index 0000000..dfbb436
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FloorCeilCallGen.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.codegen.calls
+
+import java.lang.reflect.Method
+
+import org.apache.calcite.avatica.util.TimeUnitRange
+import org.apache.calcite.avatica.util.TimeUnitRange.{MONTH, YEAR}
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, DOUBLE_TYPE_INFO, FLOAT_TYPE_INFO}
+import org.apache.flink.table.codegen.CodeGenUtils.{getEnum, primitiveTypeTermForTypeInfo, qualifyMethod}
+import org.apache.flink.table.codegen.calls.CallGenerator.generateCallIfArgsNotNull
+import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression}
+
+/**
+  * Generates floor/ceil function calls.
+  *
+  * With one operand, FLOAT, DOUBLE and DECIMAL values are rounded via `arithmeticMethod`;
+  * operands of any other type are returned unchanged.
+  *
+  * With two operands (a temporal value and a [[TimeUnitRange]]), rounding to YEAR or MONTH
+  * calls `temporalMethod`; all other units call `arithmeticMethod` with the value and the
+  * multiplier of the unit's start unit, both cast to the value's primitive type.
+  *
+  * @param arithmeticMethod method used for arithmetic rounding
+  * @param temporalMethod method used for rounding temporal values to YEAR or MONTH; required
+  *                       for two-operand calls with such a unit
+  */
+class FloorCeilCallGen(
+    arithmeticMethod: Method,
+    temporalMethod: Option[Method] = None)
+  extends MultiTypeMethodCallGen(arithmeticMethod) {
+
+  /**
+    * @throws IllegalArgumentException if the call does not have one or two operands
+    * @throws IllegalStateException if a YEAR/MONTH rounding is requested but no
+    *                               `temporalMethod` was configured
+    */
+  override def generate(
+      codeGenerator: CodeGenerator,
+      operands: Seq[GeneratedExpression])
+    : GeneratedExpression = operands.size match {
+    // arithmetic
+    case 1 =>
+      operands.head.resultType match {
+        case FLOAT_TYPE_INFO | DOUBLE_TYPE_INFO | BIG_DEC_TYPE_INFO =>
+          super.generate(codeGenerator, operands)
+        case _ =>
+          operands.head // no floor/ceil necessary
+      }
+
+    // temporal
+    case 2 =>
+      val operand = operands.head
+      val unit = getEnum(operands(1)).asInstanceOf[TimeUnitRange]
+      val internalType = primitiveTypeTermForTypeInfo(operand.resultType)
+
+      generateCallIfArgsNotNull(codeGenerator.nullCheck, operand.resultType, operands) {
+        (terms) =>
+          unit match {
+            case YEAR | MONTH =>
+              val method = temporalMethod.getOrElse(
+                throw new IllegalStateException(
+                  s"No temporal floor/ceil method configured for unit $unit."))
+              s"""
+                |($internalType) ${qualifyMethod(method)}(${terms(1)}, ${terms.head})
+                |""".stripMargin
+            case _ =>
+              s"""
+                |${qualifyMethod(arithmeticMethod)}(
+                | ($internalType) ${terms.head},
+                | ($internalType) ${unit.startUnit.multiplier.intValue()})
+                |""".stripMargin
+          }
+      }
+
+    case n =>
+      // fail with a descriptive message instead of an opaque MatchError
+      throw new IllegalArgumentException(
+        s"Floor/ceil expects one or two operands but got $n.")
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
new file mode 100644
index 0000000..dfc9055
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
@@ -0,0 +1,369 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.codegen.calls
+
+import java.lang.reflect.Method
+
+import org.apache.calcite.avatica.util.TimeUnitRange
+import org.apache.calcite.sql.SqlOperator
+import org.apache.calcite.sql.fun.SqlStdOperatorTable._
+import org.apache.calcite.sql.fun.SqlTrimFunction
+import org.apache.calcite.util.BuiltInMethod
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
+import org.apache.flink.table.functions.utils.{TableSqlFunction, ScalarSqlFunction}
+
+import scala.collection.mutable
+
+/**
+  * Global hub for user-defined and built-in advanced SQL functions.
+  *
+  * Built-in functions are kept in a registry keyed by SQL operator and declared operand types.
+  * User-defined scalar and table functions are resolved directly from their SQL operator.
+  */
+object FunctionGenerator {
+
+  // registry of built-in functions: (SQL operator, declared operand types) -> call generator
+  private val sqlFunctions: mutable.Map[(SqlOperator, Seq[TypeInformation[_]]), CallGenerator] =
+    mutable.Map()
+
+  // ----------------------------------------------------------------------------------------------
+  // String functions
+  // ----------------------------------------------------------------------------------------------
+
+  addSqlFunctionMethod(
+    SUBSTRING,
+    Seq(STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
+    STRING_TYPE_INFO,
+    BuiltInMethod.SUBSTRING.method)
+
+  addSqlFunctionMethod(
+    SUBSTRING,
+    Seq(STRING_TYPE_INFO, INT_TYPE_INFO),
+    STRING_TYPE_INFO,
+    BuiltInMethod.SUBSTRING.method)
+
+  addSqlFunction(
+    TRIM,
+    Seq(new GenericTypeInfo(classOf[SqlTrimFunction.Flag]), STRING_TYPE_INFO, STRING_TYPE_INFO),
+    new TrimCallGen())
+
+  addSqlFunctionMethod(
+    CHAR_LENGTH,
+    Seq(STRING_TYPE_INFO),
+    INT_TYPE_INFO,
+    BuiltInMethod.CHAR_LENGTH.method)
+
+  addSqlFunctionMethod(
+    CHARACTER_LENGTH,
+    Seq(STRING_TYPE_INFO),
+    INT_TYPE_INFO,
+    BuiltInMethod.CHAR_LENGTH.method)
+
+  addSqlFunctionMethod(
+    UPPER,
+    Seq(STRING_TYPE_INFO),
+    STRING_TYPE_INFO,
+    BuiltInMethod.UPPER.method)
+
+  addSqlFunctionMethod(
+    LOWER,
+    Seq(STRING_TYPE_INFO),
+    STRING_TYPE_INFO,
+    BuiltInMethod.LOWER.method)
+
+  addSqlFunctionMethod(
+    INITCAP,
+    Seq(STRING_TYPE_INFO),
+    STRING_TYPE_INFO,
+    BuiltInMethod.INITCAP.method)
+
+  addSqlFunctionMethod(
+    LIKE,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+    BOOLEAN_TYPE_INFO,
+    BuiltInMethod.LIKE.method)
+
+  addSqlFunctionMethod(
+    LIKE,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
+    BOOLEAN_TYPE_INFO,
+    BuiltInMethods.LIKE_WITH_ESCAPE)
+
+  addSqlFunctionNotMethod(
+    NOT_LIKE,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+    BuiltInMethod.LIKE.method)
+
+  addSqlFunctionMethod(
+    SIMILAR_TO,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+    BOOLEAN_TYPE_INFO,
+    BuiltInMethod.SIMILAR.method)
+
+  addSqlFunctionMethod(
+    SIMILAR_TO,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
+    BOOLEAN_TYPE_INFO,
+    BuiltInMethods.SIMILAR_WITH_ESCAPE)
+
+  addSqlFunctionNotMethod(
+    NOT_SIMILAR_TO,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+    BuiltInMethod.SIMILAR.method)
+
+  addSqlFunctionMethod(
+    POSITION,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+    INT_TYPE_INFO,
+    BuiltInMethod.POSITION.method)
+
+  addSqlFunctionMethod(
+    OVERLAY,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO),
+    STRING_TYPE_INFO,
+    BuiltInMethod.OVERLAY.method)
+
+  addSqlFunctionMethod(
+    OVERLAY,
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
+    STRING_TYPE_INFO,
+    BuiltInMethod.OVERLAY.method)
+
+  // ----------------------------------------------------------------------------------------------
+  // Arithmetic functions
+  // ----------------------------------------------------------------------------------------------
+
+  addSqlFunctionMethod(
+    LOG10,
+    Seq(DOUBLE_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.LOG10)
+
+  addSqlFunctionMethod(
+    LN,
+    Seq(DOUBLE_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.LN)
+
+  addSqlFunctionMethod(
+    EXP,
+    Seq(DOUBLE_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.EXP)
+
+  addSqlFunctionMethod(
+    POWER,
+    Seq(DOUBLE_TYPE_INFO, DOUBLE_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.POWER)
+
+  addSqlFunctionMethod(
+    POWER,
+    Seq(DOUBLE_TYPE_INFO, BIG_DEC_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.POWER_DEC)
+
+  addSqlFunction(
+    ABS,
+    Seq(DOUBLE_TYPE_INFO),
+    new MultiTypeMethodCallGen(BuiltInMethods.ABS))
+
+  addSqlFunction(
+    ABS,
+    Seq(BIG_DEC_TYPE_INFO),
+    new MultiTypeMethodCallGen(BuiltInMethods.ABS_DEC))
+
+  addSqlFunction(
+    FLOOR,
+    Seq(DOUBLE_TYPE_INFO),
+    new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
+
+  addSqlFunction(
+    FLOOR,
+    Seq(BIG_DEC_TYPE_INFO),
+    new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
+
+  addSqlFunction(
+    CEIL,
+    Seq(DOUBLE_TYPE_INFO),
+    new FloorCeilCallGen(BuiltInMethod.CEIL.method))
+
+  addSqlFunction(
+    CEIL,
+    Seq(BIG_DEC_TYPE_INFO),
+    new FloorCeilCallGen(BuiltInMethod.CEIL.method))
+
+  // ----------------------------------------------------------------------------------------------
+  // Temporal functions
+  // ----------------------------------------------------------------------------------------------
+
+  addSqlFunctionMethod(
+    EXTRACT_DATE,
+    Seq(new GenericTypeInfo(classOf[TimeUnitRange]), LONG_TYPE_INFO),
+    LONG_TYPE_INFO,
+    BuiltInMethod.UNIX_DATE_EXTRACT.method)
+
+  addSqlFunctionMethod(
+    EXTRACT_DATE,
+    Seq(new GenericTypeInfo(classOf[TimeUnitRange]), SqlTimeTypeInfo.DATE),
+    LONG_TYPE_INFO,
+    BuiltInMethod.UNIX_DATE_EXTRACT.method)
+
+  addSqlFunction(
+    FLOOR,
+    Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
+    new FloorCeilCallGen(
+      BuiltInMethod.FLOOR.method,
+      Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
+
+  addSqlFunction(
+    FLOOR,
+    Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
+    new FloorCeilCallGen(
+      BuiltInMethod.FLOOR.method,
+      Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
+
+  addSqlFunction(
+    FLOOR,
+    Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
+    new FloorCeilCallGen(
+      BuiltInMethod.FLOOR.method,
+      Some(BuiltInMethod.UNIX_TIMESTAMP_FLOOR.method)))
+
+  addSqlFunction(
+    CEIL,
+    Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
+    new FloorCeilCallGen(
+      BuiltInMethod.CEIL.method,
+      Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
+
+  addSqlFunction(
+    CEIL,
+    Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
+    new FloorCeilCallGen(
+      BuiltInMethod.CEIL.method,
+      Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
+
+  addSqlFunction(
+    CEIL,
+    Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
+    new FloorCeilCallGen(
+      BuiltInMethod.CEIL.method,
+      Some(BuiltInMethod.UNIX_TIMESTAMP_CEIL.method)))
+
+  addSqlFunction(
+    CURRENT_DATE,
+    Seq(),
+    new CurrentTimePointCallGen(SqlTimeTypeInfo.DATE, local = false))
+
+  addSqlFunction(
+    CURRENT_TIME,
+    Seq(),
+    new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = false))
+
+  addSqlFunction(
+    CURRENT_TIMESTAMP,
+    Seq(),
+    new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = false))
+
+  addSqlFunction(
+    LOCALTIME,
+    Seq(),
+    new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = true))
+
+  addSqlFunction(
+    LOCALTIMESTAMP,
+    Seq(),
+    new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = true))
+
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Returns a [[CallGenerator]] that generates all required code for calling the given
+    * [[SqlOperator]].
+    *
+    * Built-in functions are first looked up by their exact operand types. If no exact signature
+    * is registered, any registered signature of the same operator and arity is used whose
+    * declared operand types are compatible with the actual ones (see `isCompatible`). If several
+    * signatures are compatible, whichever the (unordered) registry yields first is used.
+    *
+    * @param sqlOperator SQL operator (might be overloaded)
+    * @param operandTypes actual operand types
+    * @param resultType expected return type; only used for user-defined functions
+    * @return [[CallGenerator]], or [[None]] if no matching built-in function is registered
+    */
+  def getCallGenerator(
+      sqlOperator: SqlOperator,
+      operandTypes: Seq[TypeInformation[_]],
+      resultType: TypeInformation[_])
+    : Option[CallGenerator] = sqlOperator match {
+
+    // user-defined scalar function
+    case ssf: ScalarSqlFunction =>
+      Some(
+        new ScalarFunctionCallGen(
+          ssf.getScalarFunction,
+          operandTypes,
+          resultType
+        )
+      )
+
+    // user-defined table function
+    case tsf: TableSqlFunction =>
+      Some(
+        new TableFunctionCallGen(
+          tsf.getTableFunction,
+          operandTypes,
+          resultType
+        )
+      )
+
+    // built-in scalar function
+    case _ =>
+      sqlFunctions.get((sqlOperator, operandTypes)).orElse {
+        sqlFunctions
+          .find { case ((op, declaredTypes), _) =>
+            op == sqlOperator &&
+              declaredTypes.length == operandTypes.length &&
+              declaredTypes.zip(operandTypes).forall { case (d, a) => isCompatible(d, a) }
+          }
+          .map(_._2)
+      }
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Checks whether an operand of type `actual` may be passed where a registered signature
+    * declares `declared`.
+    *
+    * Identical types always match. This includes non-basic types such as [[SqlTimeTypeInfo]]
+    * or [[GenericTypeInfo]], so signatures containing them can still match when another operand
+    * needs an auto-cast. Basic types also match if `actual` can be auto-cast to `declared`.
+    */
+  private def isCompatible(declared: TypeInformation[_], actual: TypeInformation[_]): Boolean =
+    (declared, actual) match {
+      case (d, a) if d == a => true
+      case (d: BasicTypeInfo[_], a: BasicTypeInfo[_]) => a.shouldAutocastTo(d)
+      case _ => false
+    }
+
+  /** Registers a built-in function implemented by calling `method`, returning `returnType`. */
+  private def addSqlFunctionMethod(
+      sqlOperator: SqlOperator,
+      operandTypes: Seq[TypeInformation[_]],
+      returnType: TypeInformation[_],
+      method: Method)
+    : Unit = {
+    sqlFunctions((sqlOperator, operandTypes)) = new MethodCallGen(returnType, method)
+  }
+
+  /** Registers a built-in function implemented as the negation of a boolean `method` call. */
+  private def addSqlFunctionNotMethod(
+      sqlOperator: SqlOperator,
+      operandTypes: Seq[TypeInformation[_]],
+      method: Method)
+    : Unit = {
+    sqlFunctions((sqlOperator, operandTypes)) =
+      new NotCallGenerator(new MethodCallGen(BOOLEAN_TYPE_INFO, method))
+  }
+
+  /** Registers a built-in function with a custom [[CallGenerator]]. */
+  private def addSqlFunction(
+      sqlOperator: SqlOperator,
+      operandTypes: Seq[TypeInformation[_]],
+      callGenerator: CallGenerator)
+    : Unit = {
+    sqlFunctions((sqlOperator, operandTypes)) = callGenerator
+  }
+
+}