You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2021/11/22 07:05:54 UTC
[flink] 03/03: [FLINK-24781][table-planner] Refactor cast of literals to use CastExecutor
This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 9f7eef293f723800945a9759c50adbf8786a2bd4
Author: slinkydeveloper <fr...@gmail.com>
AuthorDate: Tue Nov 16 10:48:08 2021 +0100
[FLINK-24781][table-planner] Refactor cast of literals to use CastExecutor
Signed-off-by: slinkydeveloper <fr...@gmail.com>
This closes #17800.
---
.../CodeGeneratedExpressionCastExecutor.java | 3 +-
.../flink/table/planner/codegen/CodeGenUtils.scala | 26 ++++++-
.../table/planner/codegen/GenerateUtils.scala | 16 ----
.../planner/codegen/calls/BuiltInMethods.scala | 1 -
.../table/planner/codegen/calls/IfCallGen.scala | 7 +-
.../planner/codegen/calls/ScalarOperatorGens.scala | 89 ++++++++++++----------
.../validation/ScalarOperatorsValidationTest.scala | 12 +--
7 files changed, 85 insertions(+), 69 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java
index 7c361ac..6e57593 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java
@@ -57,7 +57,8 @@ class CodeGeneratedExpressionCastExecutor<IN, OUT> implements CastExecutor<IN, O
throw (TableException) e.getCause();
}
throw new TableException(
- "Cannot execute the compiled expression for an unknown cause", e);
+ "Cannot execute the compiled expression for an unknown cause. " + e.getCause(),
+ e);
}
}
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
index 22bb463..b21d097 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
@@ -21,7 +21,6 @@ package org.apache.flink.table.planner.codegen
import java.lang.reflect.Method
import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Object => JObject, Short => JShort}
import java.util.concurrent.atomic.AtomicLong
-
import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.functions.RuntimeContext
import org.apache.flink.core.memory.MemorySegment
@@ -33,10 +32,10 @@ import org.apache.flink.table.data.util.DataFormatConverters.IdentityConverter
import org.apache.flink.table.data.utils.JoinedRowData
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputFieldUnboxing, generateNonNullField}
+import org.apache.flink.table.planner.codegen.calls.BuiltInMethods.BINARY_STRING_DATA_FROM_STRING
import org.apache.flink.table.runtime.dataview.StateDataViewStore
import org.apache.flink.table.runtime.generated.{AggsHandleFunction, HashFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction}
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
-import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
import org.apache.flink.table.runtime.typeutils.TypeCheckUtils
import org.apache.flink.table.runtime.util.{MurmurHashUtil, TimeWindowUtil}
import org.apache.flink.table.types.DataType
@@ -46,6 +45,7 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldCount, getPrecision, getScale}
import org.apache.flink.table.types.logical.utils.LogicalTypeUtils.toInternalConversionClass
import org.apache.flink.table.types.utils.DataTypeUtils.isInternal
+import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.types.{Row, RowKind}
import scala.annotation.tailrec
@@ -195,6 +195,28 @@ object CodeGenUtils {
case _ => boxedTypeTermForType(t)
}
+ /**
+ * Converts values to their stringified representation so they can be embedded in generated code.
+ *
+ * This method doesn't support complex types.
+ */
+ def primitiveLiteralForType(value: Any): String = value match {
+ // ordered by type root definition
+ case _: JBoolean => value.toString
+ case _: JByte => s"((byte)$value)"
+ case _: JShort => s"((short)$value)"
+ case _: JInt => value.toString
+ case _: JLong => value.toString + "L"
+ case _: JFloat => value.toString + "f"
+ case _: JDouble => value.toString + "d"
+ case sd: StringData =>
+ qualifyMethod(BINARY_STRING_DATA_FROM_STRING) + "(\"" +
+ EncodingUtils.escapeJava(sd.toString) + "\")"
+ case td: TimestampData =>
+ s"$TIMESTAMP_DATA.fromEpochMillis(${td.getMillisecond}L, ${td.getNanoOfMillisecond})"
+ case _ => throw new IllegalArgumentException("Illegal literal type: " + value.getClass)
+ }
+
@tailrec
def boxedTypeTermForType(t: LogicalType): String = t.getTypeRoot match {
// ordered by type root definition
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
index d113953..cc612ac 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
@@ -142,22 +142,6 @@ object GenerateUtils {
/**
- * Generates a string result call with auxiliary statements and result expression.
- * This will convert the String result to BinaryStringData.
- */
- def generateStringResultCallWithStmtIfArgsNotNull(
- ctx: CodeGeneratorContext,
- operands: Seq[GeneratedExpression],
- returnType: LogicalType)
- (call: Seq[String] => (String, String)): GeneratedExpression = {
- generateCallWithStmtIfArgsNotNull(ctx, returnType, operands) {
- args =>
- val (stmt, result) = call(args)
- (stmt, s"$BINARY_STRING.fromString($result)")
- }
- }
-
- /**
* Generates a call with the nullable args.
*/
def generateCallIfArgsNullable(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala
index 308826d..824f362 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala
@@ -29,7 +29,6 @@ import org.apache.flink.table.data.binary.{BinaryStringData, BinaryStringDataUti
import java.lang.reflect.Method
import java.lang.{Byte => JByte, Integer => JInteger, Long => JLong, Short => JShort}
-import java.time.ZoneId
import java.util.TimeZone
object BuiltInMethods {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
index af8061c..5fe1dd1 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
@@ -19,10 +19,9 @@
package org.apache.flink.table.planner.codegen.calls
import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, primitiveDefaultValue, primitiveTypeTermForType}
-import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.toCastContext
-import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGenUtils, CodeGeneratorContext, GeneratedExpression}
+import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.toCodegenCastContext
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression}
import org.apache.flink.table.planner.functions.casting.{CastRuleProvider, ExpressionCodeGeneratorCastRule}
-import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
import org.apache.flink.table.types.logical.LogicalType
/**
@@ -86,7 +85,7 @@ class IfCallGen() extends CallGenerator {
rule match {
case codeGeneratorCastRule: ExpressionCodeGeneratorCastRule[_, _] =>
codeGeneratorCastRule.generateExpression(
- toCastContext(ctx),
+ toCodegenCastContext(ctx),
expr.resultTerm,
expr.resultType,
targetType
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
index 8554e4f..dc8bb63 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
@@ -20,13 +20,14 @@ package org.apache.flink.table.planner.codegen.calls
import org.apache.flink.table.api.{TableException, ValidationException}
import org.apache.flink.table.data.binary.BinaryArrayData
-import org.apache.flink.table.planner.functions.casting.{CastRuleProvider, CodeGeneratorCastRule, ExpressionCodeGeneratorCastRule}
+import org.apache.flink.table.planner.functions.casting.{CastRule, CastRuleProvider, CodeGeneratorCastRule, ExpressionCodeGeneratorCastRule}
import org.apache.flink.table.data.util.MapDataUtil
+import org.apache.flink.table.data.utils.CastExecutor
import org.apache.flink.table.data.writer.{BinaryArrayWriter, BinaryRowWriter}
import org.apache.flink.table.planner.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull, binaryWriterWriteField, binaryWriterWriteNull, _}
import org.apache.flink.table.planner.codegen.GenerateUtils._
import org.apache.flink.table.planner.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE}
-import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGenUtils, CodeGeneratorContext, GeneratedExpression}
+import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGeneratorContext, GeneratedExpression}
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
import org.apache.flink.table.runtime.functions.SqlFunctionUtils
import org.apache.flink.table.runtime.types.PlannerTypeUtils
@@ -42,6 +43,7 @@ import org.apache.flink.table.utils.DateTimeUtils
import org.apache.flink.util.Preconditions.checkArgument
import org.apache.flink.table.utils.DateTimeUtils.MILLIS_PER_DAY
+import java.time.ZoneId
import java.util.Arrays.asList
import scala.collection.JavaConversions._
@@ -487,7 +489,7 @@ object ScalarOperatorGens {
// for performance, we cast literal string to literal time.
else if (isTimePoint(left.resultType) && isCharacterString(right.resultType)) {
if (right.literal) {
- generateEquals(ctx, left, generateCastStringLiteralToDateTime(ctx, right, left.resultType))
+ generateEquals(ctx, left, generateCastLiteral(ctx, right, left.resultType))
} else {
generateEquals(ctx, left, generateCast(ctx, right, left.resultType))
}
@@ -496,7 +498,7 @@ object ScalarOperatorGens {
if (left.literal) {
generateEquals(
ctx,
- generateCastStringLiteralToDateTime(ctx, left, right.resultType),
+ generateCastLiteral(ctx, left, right.resultType),
right)
} else {
generateEquals(ctx, generateCast(ctx, left, right.resultType), right)
@@ -946,7 +948,7 @@ object ScalarOperatorGens {
// Generate the code block
val castCodeBlock = codeGeneratorCastRule.generateCodeBlock(
- toCastContext(ctx),
+ toCodegenCastContext(ctx),
operand.resultTerm,
operand.nullTerm,
inputType,
@@ -1942,42 +1944,43 @@ object ScalarOperatorGens {
}
}
- private def generateCastStringLiteralToDateTime(
- ctx: CodeGeneratorContext,
- stringLiteral: GeneratedExpression,
- expectType: LogicalType): GeneratedExpression = {
- checkArgument(stringLiteral.literal)
- if (java.lang.Boolean.valueOf(stringLiteral.nullTerm)) {
- return generateNullLiteral(expectType, nullCheck = true)
+ /**
+ * This method supports casting literals to non-composite types (primitives, strings, date/time).
+ * Every cast result is declared as a class member so that it can be reused.
+ */
+ private def generateCastLiteral(
+ ctx: CodeGeneratorContext,
+ literalExpr: GeneratedExpression,
+ resultType: LogicalType): GeneratedExpression = {
+ checkArgument(literalExpr.literal)
+ if (java.lang.Boolean.valueOf(literalExpr.nullTerm)) {
+ return generateNullLiteral(resultType, nullCheck = true)
}
- val stringValue = stringLiteral.literalValue.get.toString
- val literalCode = expectType.getTypeRoot match {
- case DATE =>
- DateTimeUtils.dateStringToUnixDate(stringValue) match {
- case null => throw new ValidationException(s"String '$stringValue' is not a valid date")
- case v => v
- }
- case TIME_WITHOUT_TIME_ZONE =>
- DateTimeUtils.timeStringToUnixDate(stringValue) match {
- case null => throw new ValidationException(s"String '$stringValue' is not a valid time")
- case v => v
- }
- case TIMESTAMP_WITHOUT_TIME_ZONE =>
- DateTimeUtils.toTimestampData(stringValue) match {
- case null =>
- throw new ValidationException(s"String '$stringValue' is not a valid timestamp")
- case v => s"${CodeGenUtils.TIMESTAMP_DATA}.fromEpochMillis(" +
- s"${v.getMillisecond}L, ${v.getNanoOfMillisecond})"
- }
- case _ => throw new UnsupportedOperationException
+ val castExecutor = CastRuleProvider.create(
+ toCastContext(ctx),
+ literalExpr.resultType,
+ resultType
+ ).asInstanceOf[CastExecutor[Any, Any]]
+
+ if (castExecutor == null) {
+ throw new CodeGenException(
+ s"Unsupported casting from ${literalExpr.resultType} to $resultType")
}
- val typeTerm = primitiveTypeTermForType(expectType)
- val resultTerm = newName("stringToTime")
- val stmt = s"$typeTerm $resultTerm = $literalCode;"
- ctx.addReusableMember(stmt)
- GeneratedExpression(resultTerm, "false", "", expectType)
+ try {
+ val result = castExecutor.cast(literalExpr.literalValue.get)
+ val resultTerm = newName("stringToTime")
+
+ val declStmt =
+ s"${primitiveTypeTermForType(resultType)} $resultTerm = ${primitiveLiteralForType(result)};"
+
+ ctx.addReusableMember(declStmt)
+ GeneratedExpression(resultTerm, "false", "", resultType, Some(result))
+ } catch {
+ case e: Throwable =>
+ throw new ValidationException("Error when casting literal: " + e.getMessage, e)
+ }
}
private def generateArrayComparison(
@@ -2169,7 +2172,7 @@ object ScalarOperatorGens {
rule match {
case codeGeneratorCastRule: ExpressionCodeGeneratorCastRule[_, _] =>
operandTerm => codeGeneratorCastRule.generateExpression(
- toCastContext(ctx),
+ toCodegenCastContext(ctx),
operandTerm,
operandType,
resultType
@@ -2179,7 +2182,7 @@ object ScalarOperatorGens {
}
}
- def toCastContext(ctx: CodeGeneratorContext): CodeGeneratorCastRule.Context = {
+ def toCodegenCastContext(ctx: CodeGeneratorContext): CodeGeneratorCastRule.Context = {
new CodeGeneratorCastRule.Context {
override def getSessionTimeZoneTerm: String = ctx.addReusableSessionTimeZone()
override def declareVariable(ty: String, variablePrefix: String): String =
@@ -2193,4 +2196,12 @@ object ScalarOperatorGens {
}
}
+ def toCastContext(ctx: CodeGeneratorContext): CastRule.Context = {
+ new CastRule.Context {
+ override def getSessionZoneId: ZoneId = ctx.tableConfig.getLocalTimeZone
+
+ override def getClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader
+ }
+ }
+
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
index 4b27008..5fc4b72 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
@@ -88,24 +88,24 @@ class ScalarOperatorsValidationTest extends ScalarOperatorsTestBase {
@Test
def testTemporalTypeEqualsInvalidStringLiteral(): Unit = {
testExpectedSqlException(
- "f15 = 'invalid'", "is not a valid date",
+ "f15 = 'invalid'", "java.time.DateTimeException",
classOf[ValidationException])
testExpectedSqlException(
- "'invalid' = f15", "is not a valid date",
+ "'invalid' = f15", "java.time.DateTimeException",
classOf[ValidationException])
testExpectedSqlException(
- "f21 = 'invalid'", "is not a valid time",
+ "f21 = 'invalid'", "java.time.DateTimeException",
classOf[ValidationException])
testExpectedSqlException(
- "'invalid' = f21", "is not a valid time",
+ "'invalid' = f21", "java.time.DateTimeException",
classOf[ValidationException])
testExpectedSqlException(
- "f22 = 'invalid'", "is not a valid timestamp",
+ "f22 = 'invalid'", "java.time.DateTimeException",
classOf[ValidationException])
testExpectedSqlException(
- "'invalid' = f22", "is not a valid timestamp",
+ "'invalid' = f22", "java.time.DateTimeException",
classOf[ValidationException])
}
}