You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2021/11/22 07:05:54 UTC

[flink] 03/03: [FLINK-24781][table-planner] Refactor cast of literals to use CastExecutor

This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 9f7eef293f723800945a9759c50adbf8786a2bd4
Author: slinkydeveloper <fr...@gmail.com>
AuthorDate: Tue Nov 16 10:48:08 2021 +0100

    [FLINK-24781][table-planner] Refactor cast of literals to use CastExecutor
    
    Signed-off-by: slinkydeveloper <fr...@gmail.com>
    
    This closes #17800.
---
 .../CodeGeneratedExpressionCastExecutor.java       |  3 +-
 .../flink/table/planner/codegen/CodeGenUtils.scala | 26 ++++++-
 .../table/planner/codegen/GenerateUtils.scala      | 16 ----
 .../planner/codegen/calls/BuiltInMethods.scala     |  1 -
 .../table/planner/codegen/calls/IfCallGen.scala    |  7 +-
 .../planner/codegen/calls/ScalarOperatorGens.scala | 89 ++++++++++++----------
 .../validation/ScalarOperatorsValidationTest.scala | 12 +--
 7 files changed, 85 insertions(+), 69 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java
index 7c361ac..6e57593 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java
@@ -57,7 +57,8 @@ class CodeGeneratedExpressionCastExecutor<IN, OUT> implements CastExecutor<IN, O
                 throw (TableException) e.getCause();
             }
             throw new TableException(
-                    "Cannot execute the compiled expression for an unknown cause", e);
+                    "Cannot execute the compiled expression for an unknown cause. " + e.getCause(),
+                    e);
         }
     }
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
index 22bb463..b21d097 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
@@ -21,7 +21,6 @@ package org.apache.flink.table.planner.codegen
 import java.lang.reflect.Method
 import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Object => JObject, Short => JShort}
 import java.util.concurrent.atomic.AtomicLong
-
 import org.apache.flink.api.common.ExecutionConfig
 import org.apache.flink.api.common.functions.RuntimeContext
 import org.apache.flink.core.memory.MemorySegment
@@ -33,10 +32,10 @@ import org.apache.flink.table.data.util.DataFormatConverters.IdentityConverter
 import org.apache.flink.table.data.utils.JoinedRowData
 import org.apache.flink.table.functions.UserDefinedFunction
 import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputFieldUnboxing, generateNonNullField}
+import org.apache.flink.table.planner.codegen.calls.BuiltInMethods.BINARY_STRING_DATA_FROM_STRING
 import org.apache.flink.table.runtime.dataview.StateDataViewStore
 import org.apache.flink.table.runtime.generated.{AggsHandleFunction, HashFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction}
 import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
-import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
 import org.apache.flink.table.runtime.typeutils.TypeCheckUtils
 import org.apache.flink.table.runtime.util.{MurmurHashUtil, TimeWindowUtil}
 import org.apache.flink.table.types.DataType
@@ -46,6 +45,7 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldCount, getPrecision, getScale}
 import org.apache.flink.table.types.logical.utils.LogicalTypeUtils.toInternalConversionClass
 import org.apache.flink.table.types.utils.DataTypeUtils.isInternal
+import org.apache.flink.table.utils.EncodingUtils
 import org.apache.flink.types.{Row, RowKind}
 
 import scala.annotation.tailrec
@@ -195,6 +195,28 @@ object CodeGenUtils {
     case _ => boxedTypeTermForType(t)
   }
 
+  /**
+   * Converts values to the Java source representation included in the codegen (no complex
+   * types; cases follow type root order). Non-finite floats lack a literal, so bits are emitted.
+   */
+  def primitiveLiteralForType(value: Any): String = value match {
+    case _: JBoolean => value.toString
+    case _: JByte => s"((byte)$value)"
+    case _: JShort => s"((short)$value)"
+    case _: JInt => value.toString
+    case _: JLong => value.toString + "L"
+    case f: JFloat if !JFloat.isFinite(f) => s"Float.intBitsToFloat(${JFloat.floatToIntBits(f)})"
+    case _: JFloat => value.toString + "f"
+    case d: JDouble if !JDouble.isFinite(d) =>
+      s"Double.longBitsToDouble(${JDouble.doubleToLongBits(d)}L)"
+    case _: JDouble => value.toString + "d"
+    case sd: StringData => qualifyMethod(BINARY_STRING_DATA_FROM_STRING) + "(\"" +
+      EncodingUtils.escapeJava(sd.toString) + "\")"
+    case td: TimestampData =>
+      s"$TIMESTAMP_DATA.fromEpochMillis(${td.getMillisecond}L, ${td.getNanoOfMillisecond})"
+    case _ => throw new IllegalArgumentException("Illegal literal type: " + value.getClass)
+  }
+
   @tailrec
   def boxedTypeTermForType(t: LogicalType): String = t.getTypeRoot match {
     // ordered by type root definition
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
index d113953..cc612ac 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
@@ -142,22 +142,6 @@ object GenerateUtils {
 
 
   /**
-    * Generates a string result call with auxiliary statements and result expression.
-    * This will convert the String result to BinaryStringData.
-    */
-  def generateStringResultCallWithStmtIfArgsNotNull(
-      ctx: CodeGeneratorContext,
-      operands: Seq[GeneratedExpression],
-      returnType: LogicalType)
-      (call: Seq[String] => (String, String)): GeneratedExpression = {
-    generateCallWithStmtIfArgsNotNull(ctx, returnType, operands) {
-      args =>
-        val (stmt, result) = call(args)
-        (stmt, s"$BINARY_STRING.fromString($result)")
-    }
-  }
-
-  /**
     * Generates a call with the nullable args.
     */
   def generateCallIfArgsNullable(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala
index 308826d..824f362 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala
@@ -29,7 +29,6 @@ import org.apache.flink.table.data.binary.{BinaryStringData, BinaryStringDataUti
 
 import java.lang.reflect.Method
 import java.lang.{Byte => JByte, Integer => JInteger, Long => JLong, Short => JShort}
-import java.time.ZoneId
 import java.util.TimeZone
 
 object BuiltInMethods {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
index af8061c..5fe1dd1 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
@@ -19,10 +19,9 @@
 package org.apache.flink.table.planner.codegen.calls
 
 import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, primitiveDefaultValue, primitiveTypeTermForType}
-import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.toCastContext
-import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGenUtils, CodeGeneratorContext, GeneratedExpression}
+import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.toCodegenCastContext
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression}
 import org.apache.flink.table.planner.functions.casting.{CastRuleProvider, ExpressionCodeGeneratorCastRule}
-import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
 import org.apache.flink.table.types.logical.LogicalType
 
 /**
@@ -86,7 +85,7 @@ class IfCallGen() extends CallGenerator {
       rule match {
         case codeGeneratorCastRule: ExpressionCodeGeneratorCastRule[_, _] =>
           codeGeneratorCastRule.generateExpression(
-            toCastContext(ctx),
+            toCodegenCastContext(ctx),
             expr.resultTerm,
             expr.resultType,
             targetType
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
index 8554e4f..dc8bb63 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
@@ -20,13 +20,14 @@ package org.apache.flink.table.planner.codegen.calls
 
 import org.apache.flink.table.api.{TableException, ValidationException}
 import org.apache.flink.table.data.binary.BinaryArrayData
-import org.apache.flink.table.planner.functions.casting.{CastRuleProvider, CodeGeneratorCastRule, ExpressionCodeGeneratorCastRule}
+import org.apache.flink.table.planner.functions.casting.{CastRule, CastRuleProvider, CodeGeneratorCastRule, ExpressionCodeGeneratorCastRule}
 import org.apache.flink.table.data.util.MapDataUtil
+import org.apache.flink.table.data.utils.CastExecutor
 import org.apache.flink.table.data.writer.{BinaryArrayWriter, BinaryRowWriter}
 import org.apache.flink.table.planner.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull, binaryWriterWriteField, binaryWriterWriteNull, _}
 import org.apache.flink.table.planner.codegen.GenerateUtils._
 import org.apache.flink.table.planner.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE}
-import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGenUtils, CodeGeneratorContext, GeneratedExpression}
+import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGeneratorContext, GeneratedExpression}
 import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
 import org.apache.flink.table.runtime.functions.SqlFunctionUtils
 import org.apache.flink.table.runtime.types.PlannerTypeUtils
@@ -42,6 +43,7 @@ import org.apache.flink.table.utils.DateTimeUtils
 import org.apache.flink.util.Preconditions.checkArgument
 import org.apache.flink.table.utils.DateTimeUtils.MILLIS_PER_DAY
 
+import java.time.ZoneId
 import java.util.Arrays.asList
 import scala.collection.JavaConversions._
 
@@ -487,7 +489,7 @@ object ScalarOperatorGens {
     // for performance, we cast literal string to literal time.
     else if (isTimePoint(left.resultType) && isCharacterString(right.resultType)) {
       if (right.literal) {
-        generateEquals(ctx, left, generateCastStringLiteralToDateTime(ctx, right, left.resultType))
+        generateEquals(ctx, left, generateCastLiteral(ctx, right, left.resultType))
       } else {
         generateEquals(ctx, left, generateCast(ctx, right, left.resultType))
       }
@@ -496,7 +498,7 @@ object ScalarOperatorGens {
       if (left.literal) {
         generateEquals(
           ctx,
-          generateCastStringLiteralToDateTime(ctx, left, right.resultType),
+          generateCastLiteral(ctx, left, right.resultType),
           right)
       } else {
         generateEquals(ctx, generateCast(ctx, left, right.resultType), right)
@@ -946,7 +948,7 @@ object ScalarOperatorGens {
 
         // Generate the code block
         val castCodeBlock = codeGeneratorCastRule.generateCodeBlock(
-          toCastContext(ctx),
+          toCodegenCastContext(ctx),
           operand.resultTerm,
           operand.nullTerm,
           inputType,
@@ -1942,42 +1944,43 @@ object ScalarOperatorGens {
     }
   }
 
-  private def generateCastStringLiteralToDateTime(
-      ctx: CodeGeneratorContext,
-      stringLiteral: GeneratedExpression,
-      expectType: LogicalType): GeneratedExpression = {
-    checkArgument(stringLiteral.literal)
-    if (java.lang.Boolean.valueOf(stringLiteral.nullTerm)) {
-      return generateNullLiteral(expectType, nullCheck = true)
+  /**
+   * Casts a literal during code generation to a non-composite type (primitives, strings, date
+   * time) and declares the result as a reusable class member; failures raise ValidationException.
+   */
+  private def generateCastLiteral(
+      ctx: CodeGeneratorContext,
+      literalExpr: GeneratedExpression,
+      resultType: LogicalType): GeneratedExpression = {
+    checkArgument(literalExpr.literal)
+    if (java.lang.Boolean.valueOf(literalExpr.nullTerm)) {
+      return generateNullLiteral(resultType, nullCheck = true)
     }
 
-    val stringValue = stringLiteral.literalValue.get.toString
-    val literalCode = expectType.getTypeRoot match {
-      case DATE =>
-        DateTimeUtils.dateStringToUnixDate(stringValue) match {
-          case null => throw new ValidationException(s"String '$stringValue' is not a valid date")
-          case v => v
-        }
-      case TIME_WITHOUT_TIME_ZONE =>
-        DateTimeUtils.timeStringToUnixDate(stringValue) match {
-          case null => throw new ValidationException(s"String '$stringValue' is not a valid time")
-          case v => v
-        }
-      case TIMESTAMP_WITHOUT_TIME_ZONE =>
-        DateTimeUtils.toTimestampData(stringValue) match {
-          case null =>
-            throw new ValidationException(s"String '$stringValue' is not a valid timestamp")
-          case v => s"${CodeGenUtils.TIMESTAMP_DATA}.fromEpochMillis(" +
-            s"${v.getMillisecond}L, ${v.getNanoOfMillisecond})"
-        }
-      case _ => throw new UnsupportedOperationException
+    val castExecutor = CastRuleProvider.create(
+      toCastContext(ctx),
+      literalExpr.resultType,
+      resultType
+    ).asInstanceOf[CastExecutor[Any, Any]]
+
+    if (castExecutor == null) {
+      throw new CodeGenException(
+        s"Unsupported casting from ${literalExpr.resultType} to $resultType")
     }
 
-    val typeTerm = primitiveTypeTermForType(expectType)
-    val resultTerm = newName("stringToTime")
-    val stmt = s"$typeTerm $resultTerm = $literalCode;"
-    ctx.addReusableMember(stmt)
-    GeneratedExpression(resultTerm, "false", "", expectType)
+    try {
+      val result = castExecutor.cast(literalExpr.literalValue.get)
+      val resultTerm = newName("stringToTime")
+
+      val declStmt =
+        s"${primitiveTypeTermForType(resultType)} $resultTerm = ${primitiveLiteralForType(result)};"
+
+      ctx.addReusableMember(declStmt)
+      GeneratedExpression(resultTerm, "false", "", resultType, Some(result))
+    } catch {
+      case scala.util.control.NonFatal(e) => // fatal errors (OOM, interrupts) must propagate
+        throw new ValidationException("Error when casting literal: " + e.getMessage, e)
+    }
   }
 
   private def generateArrayComparison(
@@ -2169,7 +2172,7 @@ object ScalarOperatorGens {
     rule match {
       case codeGeneratorCastRule: ExpressionCodeGeneratorCastRule[_, _] =>
         operandTerm => codeGeneratorCastRule.generateExpression(
-          toCastContext(ctx),
+          toCodegenCastContext(ctx),
           operandTerm,
           operandType,
           resultType
@@ -2179,7 +2182,7 @@ object ScalarOperatorGens {
     }
   }
 
-  def toCastContext(ctx: CodeGeneratorContext): CodeGeneratorCastRule.Context = {
+  def toCodegenCastContext(ctx: CodeGeneratorContext): CodeGeneratorCastRule.Context = {
     new CodeGeneratorCastRule.Context {
       override def getSessionTimeZoneTerm: String = ctx.addReusableSessionTimeZone()
       override def declareVariable(ty: String, variablePrefix: String): String =
@@ -2193,4 +2196,12 @@ object ScalarOperatorGens {
     }
   }
 
+  /** Builds a [[CastRule.Context]] backed by the table config's local time zone. */
+  def toCastContext(ctx: CodeGeneratorContext): CastRule.Context = new CastRule.Context {
+    override def getSessionZoneId: ZoneId = ctx.tableConfig.getLocalTimeZone
+
+    // Looked up on every call from the calling thread's context class loader.
+    override def getClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader
+  }
+
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
index 4b27008..5fc4b72 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
@@ -88,24 +88,24 @@ class ScalarOperatorsValidationTest extends ScalarOperatorsTestBase {
   @Test
   def testTemporalTypeEqualsInvalidStringLiteral(): Unit = {
     testExpectedSqlException(
-      "f15 = 'invalid'", "is not a valid date",
+      "f15 = 'invalid'", "java.time.DateTimeException",
       classOf[ValidationException])
     testExpectedSqlException(
-      "'invalid' = f15", "is not a valid date",
+      "'invalid' = f15", "java.time.DateTimeException",
       classOf[ValidationException])
 
     testExpectedSqlException(
-      "f21 = 'invalid'", "is not a valid time",
+      "f21 = 'invalid'", "java.time.DateTimeException",
       classOf[ValidationException])
     testExpectedSqlException(
-      "'invalid' = f21", "is not a valid time",
+      "'invalid' = f21", "java.time.DateTimeException",
       classOf[ValidationException])
 
     testExpectedSqlException(
-      "f22 = 'invalid'", "is not a valid timestamp",
+      "f22 = 'invalid'", "java.time.DateTimeException",
       classOf[ValidationException])
     testExpectedSqlException(
-      "'invalid' = f22", "is not a valid timestamp",
+      "'invalid' = f22", "java.time.DateTimeException",
       classOf[ValidationException])
   }
 }