You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2020/05/20 07:03:11 UTC
[flink] 01/04: [hotfix][table] Reduce friction around logical type
roots
This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git
commit c3ff1de47cd01d7448d325c42d9ad76681e8c85d
Author: Timo Walther <tw...@apache.org>
AuthorDate: Mon May 18 11:07:22 2020 +0200
[hotfix][table] Reduce friction around logical type roots
---
.../flink/table/types/logical/LogicalTypeRoot.java | 10 +
.../types/logical/utils/LogicalTypeChecks.java | 13 +
.../types/logical/utils/LogicalTypeUtils.java | 38 +-
.../flink/table/planner/codegen/CodeGenUtils.scala | 499 ++++++++++++---------
.../planner/codegen/EqualiserCodeGenerator.scala | 17 +-
.../table/planner/codegen/ExpressionReducer.scala | 6 +-
.../table/planner/codegen/GenerateUtils.scala | 237 ++++++----
.../codegen/agg/batch/AggCodeGenHelper.scala | 34 +-
.../table/runtime/typeutils/TypeCheckUtils.java | 18 +-
9 files changed, 531 insertions(+), 341 deletions(-)
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/LogicalTypeRoot.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/LogicalTypeRoot.java
index 0079a8d..e2a97f6 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/LogicalTypeRoot.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/LogicalTypeRoot.java
@@ -37,6 +37,16 @@ import java.util.Set;
* {@code SYMBOL}, or {@code RAW}).
*
* <p>See the type-implementing classes for a more detailed description of each type.
+ *
+ * <p>Note to implementers: Whenever we perform a match against a type root (e.g. using a switch/case
+ * statement), it is recommended to:
+ * <ul>
+ * <li>Order the items by the type root definition in this class for easy readability.
+ * <li>Think about the behavior of all type roots for the implementation. A default fallback is
+ * dangerous when introducing a new type root in the future.
+ * <li>In many <b>runtime</b> cases, resolve the indirection of {@link #DISTINCT_TYPE}:
+ * {@code return myMethod(((DistinctType) type).getSourceType())}
+ * </ul>
*/
@PublicEvolving
public enum LogicalTypeRoot {
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
index 8a3e301..b6117f4 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
@@ -108,6 +108,9 @@ public final class LogicalTypeChecks {
/**
* Checks if the given type is a composite type.
*
+ * <p>Use {@link #getFieldCount(LogicalType)}, {@link #getFieldNames(LogicalType)},
+ * {@link #getFieldTypes(LogicalType)} for unified handling of composite types.
+ *
* @param logicalType Logical data type to check
* @return True if the type is composite type.
*/
@@ -198,6 +201,16 @@ public final class LogicalTypeChecks {
return logicalType.accept(FIELD_NAMES_EXTRACTOR);
}
+ /**
+ * Returns the field types of row and structured types; distinct types are resolved first.
+ */
+ public static List<LogicalType> getFieldTypes(LogicalType logicalType) {
+ if (logicalType instanceof DistinctType) {
+ return getFieldTypes(((DistinctType) logicalType).getSourceType());
+ }
+ return logicalType.getChildren();
+ }
+
private LogicalTypeChecks() {
// no instantiation
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeUtils.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeUtils.java
index 5e8be86..033d711 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeUtils.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeUtils.java
@@ -26,6 +26,7 @@ import org.apache.flink.table.data.RawValueData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.TimestampData;
+import org.apache.flink.table.types.logical.DistinctType;
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.TimestampType;
@@ -45,14 +46,23 @@ public final class LogicalTypeUtils {
/**
* Returns the conversion class for the given {@link LogicalType} that is used by the
- * table runtime.
+ * table runtime as internal data structure.
*
* @see RowData
*/
public static Class<?> toInternalConversionClass(LogicalType type) {
+ // ordered by type root definition
switch (type.getTypeRoot()) {
+ case CHAR:
+ case VARCHAR:
+ return StringData.class;
case BOOLEAN:
return Boolean.class;
+ case BINARY:
+ case VARBINARY:
+ return byte[].class;
+ case DECIMAL:
+ return DecimalData.class;
case TINYINT:
return Byte.class;
case SMALLINT:
@@ -65,32 +75,32 @@ public final class LogicalTypeUtils {
case BIGINT:
case INTERVAL_DAY_TIME:
return Long.class;
- case TIMESTAMP_WITHOUT_TIME_ZONE:
- case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
- return TimestampData.class;
case FLOAT:
return Float.class;
case DOUBLE:
return Double.class;
- case CHAR:
- case VARCHAR:
- return StringData.class;
- case DECIMAL:
- return DecimalData.class;
+ case TIMESTAMP_WITHOUT_TIME_ZONE:
+ case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
+ return TimestampData.class;
+ case TIMESTAMP_WITH_TIME_ZONE:
+ throw new UnsupportedOperationException("Unsupported type: " + type);
case ARRAY:
return ArrayData.class;
- case MAP:
case MULTISET:
+ case MAP:
return MapData.class;
case ROW:
+ case STRUCTURED_TYPE:
return RowData.class;
- case BINARY:
- case VARBINARY:
- return byte[].class;
+ case DISTINCT_TYPE:
+ return toInternalConversionClass(((DistinctType) type).getSourceType());
case RAW:
return RawValueData.class;
+ case NULL:
+ case SYMBOL:
+ case UNRESOLVED:
default:
- throw new UnsupportedOperationException("Unsupported type: " + type);
+ throw new IllegalArgumentException("Illegal type: " + type);
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
index 58b7010..6e62a3f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
@@ -40,10 +40,12 @@ import org.apache.flink.table.runtime.util.MurmurHashUtil
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical._
-import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldCount, getPrecision, getScale, hasRoot}
import org.apache.flink.table.types.logical.utils.LogicalTypeUtils.toInternalConversionClass
import org.apache.flink.types.{Row, RowKind}
+import scala.annotation.tailrec
+
object CodeGenUtils {
// ------------------------------- DEFAULT TERMS ------------------------------------------
@@ -161,117 +163,118 @@ object CodeGenUtils {
// works, but for boxed types we need this:
// Float a = 1.0f;
// Byte b = (byte)(float) a;
+ @tailrec
def primitiveTypeTermForType(t: LogicalType): String = t.getTypeRoot match {
- case INTEGER => "int"
- case BIGINT => "long"
- case SMALLINT => "short"
+ // ordered by type root definition
+ case BOOLEAN => "boolean"
case TINYINT => "byte"
+ case SMALLINT => "short"
+ case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => "int"
+ case BIGINT | INTERVAL_DAY_TIME => "long"
case FLOAT => "float"
case DOUBLE => "double"
- case BOOLEAN => "boolean"
-
- case DATE => "int"
- case TIME_WITHOUT_TIME_ZONE => "int"
- case INTERVAL_YEAR_MONTH => "int"
- case INTERVAL_DAY_TIME => "long"
-
+ case DISTINCT_TYPE => primitiveTypeTermForType(t.asInstanceOf[DistinctType].getSourceType)
case _ => boxedTypeTermForType(t)
}
+ @tailrec
def boxedTypeTermForType(t: LogicalType): String = t.getTypeRoot match {
- case INTEGER => className[JInt]
- case BIGINT => className[JLong]
- case SMALLINT => className[JShort]
+ // ordered by type root definition
+ case CHAR | VARCHAR => BINARY_STRING
+ case BOOLEAN => className[JBoolean]
+ case BINARY | VARBINARY => "byte[]"
+ case DECIMAL => className[DecimalData]
case TINYINT => className[JByte]
+ case SMALLINT => className[JShort]
+ case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => className[JInt]
+ case BIGINT | INTERVAL_DAY_TIME => className[JLong]
case FLOAT => className[JFloat]
case DOUBLE => className[JDouble]
- case BOOLEAN => className[JBoolean]
-
- case DATE => className[JInt]
- case TIME_WITHOUT_TIME_ZONE => className[JInt]
- case INTERVAL_YEAR_MONTH => className[JInt]
- case INTERVAL_DAY_TIME => className[JLong]
-
- case VARCHAR | CHAR => BINARY_STRING
- case VARBINARY | BINARY => "byte[]"
-
- case DECIMAL => className[DecimalData]
+ case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => className[TimestampData]
+ case TIMESTAMP_WITH_TIME_ZONE =>
+ throw new UnsupportedOperationException("Unsupported type: " + t)
case ARRAY => className[ArrayData]
case MULTISET | MAP => className[MapData]
- case ROW => className[RowData]
+ case ROW | STRUCTURED_TYPE => className[RowData]
case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => className[TimestampData]
-
+ case DISTINCT_TYPE => boxedTypeTermForType(t.asInstanceOf[DistinctType].getSourceType)
+ case NULL => className[JObject] // special case for untyped null literals
case RAW => className[BinaryRawValueData[_]]
-
- // special case for untyped null literals
- case NULL => className[JObject]
+ case SYMBOL | UNRESOLVED =>
+ throw new IllegalArgumentException("Illegal type: " + t)
}
/**
* Gets the default value for a primitive type, and null for generic types
*/
+ @tailrec
def primitiveDefaultValue(t: LogicalType): String = t.getTypeRoot match {
- case INTEGER | TINYINT | SMALLINT => "-1"
- case BIGINT => "-1L"
+ // ordered by type root definition
+ case CHAR | VARCHAR => s"$BINARY_STRING.EMPTY_UTF8"
+ case BOOLEAN => "false"
+ case TINYINT | SMALLINT | INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => "-1"
+ case BIGINT | INTERVAL_DAY_TIME => "-1L"
case FLOAT => "-1.0f"
case DOUBLE => "-1.0d"
- case BOOLEAN => "false"
- case VARCHAR | CHAR => s"$BINARY_STRING.EMPTY_UTF8"
- case DATE | TIME_WITHOUT_TIME_ZONE => "-1"
- case INTERVAL_YEAR_MONTH => "-1"
- case INTERVAL_DAY_TIME => "-1L"
+ case DISTINCT_TYPE => primitiveDefaultValue(t.asInstanceOf[DistinctType].getSourceType)
case _ => "null"
}
- /**
- * If it's internally compatible, don't need to DataStructure converter.
- * clazz != classOf[Row] => Row can only infer GenericType[Row].
- */
- def isInternalClass(t: DataType): Boolean = {
- val clazz = t.getConversionClass
- clazz != classOf[Object] && clazz != classOf[Row] &&
- (classOf[RowData].isAssignableFrom(clazz) ||
- clazz == toInternalConversionClass(fromDataTypeToLogicalType(t)))
- }
-
+ @tailrec
def hashCodeForType(
- ctx: CodeGeneratorContext, t: LogicalType, term: String): String = t.getTypeRoot match {
- case BOOLEAN => s"${className[JBoolean]}.hashCode($term)"
- case TINYINT => s"${className[JByte]}.hashCode($term)"
- case SMALLINT => s"${className[JShort]}.hashCode($term)"
- case INTEGER => s"${className[JInt]}.hashCode($term)"
- case BIGINT => s"${className[JLong]}.hashCode($term)"
+ ctx: CodeGeneratorContext,
+ t: LogicalType,
+ term: String)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case VARCHAR | CHAR =>
+ s"$term.hashCode()"
+ case BOOLEAN =>
+ s"${className[JBoolean]}.hashCode($term)"
+ case BINARY | VARBINARY =>
+ s"${className[MurmurHashUtil]}.hashUnsafeBytes($term, $BYTE_ARRAY_BASE_OFFSET, $term.length)"
+ case DECIMAL =>
+ s"$term.hashCode()"
+ case TINYINT =>
+ s"${className[JByte]}.hashCode($term)"
+ case SMALLINT =>
+ s"${className[JShort]}.hashCode($term)"
+ case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH =>
+ s"${className[JInt]}.hashCode($term)"
+ case BIGINT | INTERVAL_DAY_TIME => s"${className[JLong]}.hashCode($term)"
case FLOAT => s"${className[JFloat]}.hashCode($term)"
case DOUBLE => s"${className[JDouble]}.hashCode($term)"
- case VARCHAR | CHAR => s"$term.hashCode()"
- case VARBINARY | BINARY => s"${className[MurmurHashUtil]}.hashUnsafeBytes(" +
- s"$term, $BYTE_ARRAY_BASE_OFFSET, $term.length)"
- case DECIMAL => s"$term.hashCode()"
- case DATE => s"${className[JInt]}.hashCode($term)"
- case TIME_WITHOUT_TIME_ZONE => s"${className[JInt]}.hashCode($term)"
case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
s"$term.hashCode()"
- case INTERVAL_YEAR_MONTH => s"${className[JInt]}.hashCode($term)"
+ case TIMESTAMP_WITH_TIME_ZONE | ARRAY | MULTISET | MAP =>
+ throw new UnsupportedOperationException("Unsupported type: " + t)
case INTERVAL_DAY_TIME => s"${className[JLong]}.hashCode($term)"
- case ARRAY => throw new IllegalArgumentException(s"Not support type to hash: $t")
- case ROW =>
- val rowType = t.asInstanceOf[RowType]
+ case ROW | STRUCTURED_TYPE =>
+ val fieldCount = getFieldCount(t)
val subCtx = CodeGeneratorContext(ctx.tableConfig)
val genHash = HashCodeGenerator.generateRowHash(
- subCtx, rowType, "SubHashRow", (0 until rowType.getFieldCount).toArray)
+ subCtx, t, "SubHashRow", (0 until fieldCount).toArray)
ctx.addReusableInnerClass(genHash.getClassName, genHash.getCode)
val refs = ctx.addReusableObject(subCtx.references.toArray, "subRefs")
val hashFunc = newName("hashFunc")
ctx.addReusableMember(s"${classOf[HashFunction].getCanonicalName} $hashFunc;")
ctx.addReusableInitStatement(s"$hashFunc = new ${genHash.getClassName}($refs);")
s"$hashFunc.hashCode($term)"
+ case DISTINCT_TYPE =>
+ hashCodeForType(ctx, t.asInstanceOf[DistinctType].getSourceType, term)
case RAW =>
- val gt = t.asInstanceOf[TypeInformationRawType[_]]
- val serTerm = ctx.addReusableObject(
- gt.getTypeInformation.createSerializer(new ExecutionConfig), "serializer")
+ val serializer = t match {
+ case rt: RawType[_] =>
+ rt.getTypeSerializer
+ case tirt: TypeInformationRawType[_] =>
+ tirt.getTypeInformation.createSerializer(new ExecutionConfig)
+ }
+ val serTerm = ctx.addReusableObject(serializer, "serializer")
s"$BINARY_RAW_VALUE.getJavaObjectFromRawValueData($term, $serTerm).hashCode()"
+ case NULL | SYMBOL | UNRESOLVED =>
+ throw new IllegalArgumentException("Illegal type: " + t)
}
// ----------------------------------------------------------------------------------------------
@@ -406,6 +409,11 @@ object CodeGenUtils {
throw new CodeGenException("Integer expression type expected.")
}
+ def udfFieldName(udf: UserDefinedFunction): String = s"function_${udf.functionIdentifier}"
+
+ def genLogInfo(logTerm: String, format: String, argTerm: String): String =
+ s"""$logTerm.info("$format", $argTerm);"""
+
// --------------------------------------------------------------------------------
// DataFormat Operations
// --------------------------------------------------------------------------------
@@ -419,44 +427,50 @@ object CodeGenUtils {
fieldType: LogicalType) : String =
rowFieldReadAccess(ctx, index.toString, rowTerm, fieldType)
+ @tailrec
def rowFieldReadAccess(
ctx: CodeGeneratorContext,
indexTerm: String,
rowTerm: String,
- t: LogicalType) : String =
- t.getTypeRoot match {
- // primitive types
- case BOOLEAN => s"$rowTerm.getBoolean($indexTerm)"
- case TINYINT => s"$rowTerm.getByte($indexTerm)"
- case SMALLINT => s"$rowTerm.getShort($indexTerm)"
- case INTEGER => s"$rowTerm.getInt($indexTerm)"
- case BIGINT => s"$rowTerm.getLong($indexTerm)"
- case FLOAT => s"$rowTerm.getFloat($indexTerm)"
- case DOUBLE => s"$rowTerm.getDouble($indexTerm)"
- case VARCHAR | CHAR => s"(($BINARY_STRING) $rowTerm.getString($indexTerm))"
- case VARBINARY | BINARY => s"$rowTerm.getBinary($indexTerm)"
+ t: LogicalType)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case CHAR | VARCHAR =>
+ s"(($BINARY_STRING) $rowTerm.getString($indexTerm))"
+ case BOOLEAN =>
+ s"$rowTerm.getBoolean($indexTerm)"
+ case BINARY | VARBINARY =>
+ s"$rowTerm.getBinary($indexTerm)"
case DECIMAL =>
- val dt = t.asInstanceOf[DecimalType]
- s"$rowTerm.getDecimal($indexTerm, ${dt.getPrecision}, ${dt.getScale})"
-
- // temporal types
- case DATE => s"$rowTerm.getInt($indexTerm)"
- case TIME_WITHOUT_TIME_ZONE => s"$rowTerm.getInt($indexTerm)"
- case TIMESTAMP_WITHOUT_TIME_ZONE =>
- val dt = t.asInstanceOf[TimestampType]
- s"$rowTerm.getTimestamp($indexTerm, ${dt.getPrecision})"
- case TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
- val dt = t.asInstanceOf[LocalZonedTimestampType]
- s"$rowTerm.getTimestamp($indexTerm, ${dt.getPrecision})"
- case INTERVAL_YEAR_MONTH => s"$rowTerm.getInt($indexTerm)"
- case INTERVAL_DAY_TIME => s"$rowTerm.getLong($indexTerm)"
-
- // complex types
- case ARRAY => s"$rowTerm.getArray($indexTerm)"
- case MULTISET | MAP => s"$rowTerm.getMap($indexTerm)"
- case ROW => s"$rowTerm.getRow($indexTerm, ${t.asInstanceOf[RowType].getFieldCount})"
-
- case RAW => s"(($BINARY_RAW_VALUE) $rowTerm.getRawValue($indexTerm))"
+ s"$rowTerm.getDecimal($indexTerm, ${getPrecision(t)}, ${getScale(t)})"
+ case TINYINT =>
+ s"$rowTerm.getByte($indexTerm)"
+ case SMALLINT =>
+ s"$rowTerm.getShort($indexTerm)"
+ case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH =>
+ s"$rowTerm.getInt($indexTerm)"
+ case BIGINT | INTERVAL_DAY_TIME =>
+ s"$rowTerm.getLong($indexTerm)"
+ case FLOAT =>
+ s"$rowTerm.getFloat($indexTerm)"
+ case DOUBLE =>
+ s"$rowTerm.getDouble($indexTerm)"
+ case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
+ s"$rowTerm.getTimestamp($indexTerm, ${getPrecision(t)})"
+ case TIMESTAMP_WITH_TIME_ZONE =>
+ throw new UnsupportedOperationException("Unsupported type: " + t)
+ case ARRAY =>
+ s"$rowTerm.getArray($indexTerm)"
+ case MULTISET | MAP =>
+ s"$rowTerm.getMap($indexTerm)"
+ case ROW | STRUCTURED_TYPE =>
+ s"$rowTerm.getRow($indexTerm, ${getFieldCount(t)})"
+ case DISTINCT_TYPE =>
+ rowFieldReadAccess(ctx, indexTerm, rowTerm, t.asInstanceOf[DistinctType].getSourceType)
+ case RAW =>
+ s"(($BINARY_RAW_VALUE) $rowTerm.getRawValue($indexTerm))"
+ case NULL | SYMBOL | UNRESOLVED =>
+ throw new IllegalArgumentException("Illegal type: " + t)
}
// -------------------------- RowData Set Field -------------------------------
@@ -549,14 +563,22 @@ object CodeGenUtils {
def binaryRowSetNull(index: Int, rowTerm: String, t: LogicalType): String =
binaryRowSetNull(index.toString, rowTerm, t)
- def binaryRowSetNull(indexTerm: String, rowTerm: String, t: LogicalType): String = t match {
- case d: DecimalType if !DecimalData.isCompact(d.getPrecision) =>
- s"$rowTerm.setDecimal($indexTerm, null, ${d.getPrecision})"
- case d: TimestampType if !TimestampData.isCompact(d.getPrecision) =>
- s"$rowTerm.setTimestamp($indexTerm, null, ${d.getPrecision})"
- case d: LocalZonedTimestampType if !TimestampData.isCompact(d.getPrecision) =>
- s"$rowTerm.setTimestamp($indexTerm, null, ${d.getPrecision})"
- case _ => s"$rowTerm.setNullAt($indexTerm)"
+ @tailrec
+ def binaryRowSetNull(
+ indexTerm: String,
+ rowTerm: String,
+ t: LogicalType)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case DECIMAL if !DecimalData.isCompact(getPrecision(t)) =>
+ s"$rowTerm.setDecimal($indexTerm, null, ${getPrecision(t)})"
+ case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE
+ if !TimestampData.isCompact(getPrecision(t)) =>
+ s"$rowTerm.setTimestamp($indexTerm, null, ${getPrecision(t)})"
+ case DISTINCT_TYPE =>
+ binaryRowSetNull(indexTerm, rowTerm, t.asInstanceOf[DistinctType].getSourceType)
+ case _ =>
+ s"$rowTerm.setNullAt($indexTerm)"
}
def binaryRowFieldSetAccess(
@@ -566,75 +588,102 @@ object CodeGenUtils {
fieldValTerm: String): String =
binaryRowFieldSetAccess(index.toString, binaryRowTerm, fieldType, fieldValTerm)
+ @tailrec
def binaryRowFieldSetAccess(
index: String,
binaryRowTerm: String,
t: LogicalType,
- fieldValTerm: String): String =
- t.getTypeRoot match {
- case INTEGER => s"$binaryRowTerm.setInt($index, $fieldValTerm)"
- case BIGINT => s"$binaryRowTerm.setLong($index, $fieldValTerm)"
- case SMALLINT => s"$binaryRowTerm.setShort($index, $fieldValTerm)"
- case TINYINT => s"$binaryRowTerm.setByte($index, $fieldValTerm)"
- case FLOAT => s"$binaryRowTerm.setFloat($index, $fieldValTerm)"
- case DOUBLE => s"$binaryRowTerm.setDouble($index, $fieldValTerm)"
- case BOOLEAN => s"$binaryRowTerm.setBoolean($index, $fieldValTerm)"
- case DATE => s"$binaryRowTerm.setInt($index, $fieldValTerm)"
- case TIME_WITHOUT_TIME_ZONE => s"$binaryRowTerm.setInt($index, $fieldValTerm)"
- case TIMESTAMP_WITHOUT_TIME_ZONE =>
- val dt = t.asInstanceOf[TimestampType]
- s"$binaryRowTerm.setTimestamp($index, $fieldValTerm, ${dt.getPrecision})"
- case TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
- val dt = t.asInstanceOf[LocalZonedTimestampType]
- s"$binaryRowTerm.setTimestamp($index, $fieldValTerm, ${dt.getPrecision})"
- case INTERVAL_YEAR_MONTH => s"$binaryRowTerm.setInt($index, $fieldValTerm)"
- case INTERVAL_DAY_TIME => s"$binaryRowTerm.setLong($index, $fieldValTerm)"
- case DECIMAL =>
- val dt = t.asInstanceOf[DecimalType]
- s"$binaryRowTerm.setDecimal($index, $fieldValTerm, ${dt.getPrecision})"
- case _ =>
- throw new CodeGenException("Fail to find binary row field setter method of LogicalType "
- + t + ".")
- }
+ fieldValTerm: String)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case BOOLEAN =>
+ s"$binaryRowTerm.setBoolean($index, $fieldValTerm)"
+ case DECIMAL =>
+ s"$binaryRowTerm.setDecimal($index, $fieldValTerm, ${getPrecision(t)})"
+ case TINYINT =>
+ s"$binaryRowTerm.setByte($index, $fieldValTerm)"
+ case SMALLINT =>
+ s"$binaryRowTerm.setShort($index, $fieldValTerm)"
+ case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH =>
+ s"$binaryRowTerm.setInt($index, $fieldValTerm)"
+ case BIGINT | INTERVAL_DAY_TIME =>
+ s"$binaryRowTerm.setLong($index, $fieldValTerm)"
+ case FLOAT =>
+ s"$binaryRowTerm.setFloat($index, $fieldValTerm)"
+ case DOUBLE =>
+ s"$binaryRowTerm.setDouble($index, $fieldValTerm)"
+ case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
+ s"$binaryRowTerm.setTimestamp($index, $fieldValTerm, ${getPrecision(t)})"
+ case DISTINCT_TYPE =>
+ binaryRowFieldSetAccess(
+ index,
+ binaryRowTerm,
+ t.asInstanceOf[DistinctType].getSourceType,
+ fieldValTerm)
+ case _ =>
+ throw new CodeGenException(
+ "Fail to find binary row field setter method of LogicalType " + t + ".")
+ }
// -------------------------- BoxedWrapperRowData Set Field -------------------------------
+ @tailrec
def boxedWrapperRowFieldSetAccess(
rowTerm: String,
indexTerm: String,
fieldTerm: String,
- t: LogicalType): String =
- t.getTypeRoot match {
- case INTEGER => s"$rowTerm.setInt($indexTerm, $fieldTerm)"
- case BIGINT => s"$rowTerm.setLong($indexTerm, $fieldTerm)"
- case SMALLINT => s"$rowTerm.setShort($indexTerm, $fieldTerm)"
- case TINYINT => s"$rowTerm.setByte($indexTerm, $fieldTerm)"
- case FLOAT => s"$rowTerm.setFloat($indexTerm, $fieldTerm)"
- case DOUBLE => s"$rowTerm.setDouble($indexTerm, $fieldTerm)"
- case BOOLEAN => s"$rowTerm.setBoolean($indexTerm, $fieldTerm)"
- case DATE => s"$rowTerm.setInt($indexTerm, $fieldTerm)"
- case TIME_WITHOUT_TIME_ZONE => s"$rowTerm.setInt($indexTerm, $fieldTerm)"
- case INTERVAL_YEAR_MONTH => s"$rowTerm.setInt($indexTerm, $fieldTerm)"
- case INTERVAL_DAY_TIME => s"$rowTerm.setLong($indexTerm, $fieldTerm)"
- case _ => s"$rowTerm.setNonPrimitiveValue($indexTerm, $fieldTerm)"
- }
+ t: LogicalType)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case BOOLEAN =>
+ s"$rowTerm.setBoolean($indexTerm, $fieldTerm)"
+ case TINYINT =>
+ s"$rowTerm.setByte($indexTerm, $fieldTerm)"
+ case SMALLINT =>
+ s"$rowTerm.setShort($indexTerm, $fieldTerm)"
+ case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH =>
+ s"$rowTerm.setInt($indexTerm, $fieldTerm)"
+ case BIGINT | INTERVAL_DAY_TIME =>
+ s"$rowTerm.setLong($indexTerm, $fieldTerm)"
+ case FLOAT =>
+ s"$rowTerm.setFloat($indexTerm, $fieldTerm)"
+ case DOUBLE =>
+ s"$rowTerm.setDouble($indexTerm, $fieldTerm)"
+ case DISTINCT_TYPE =>
+ boxedWrapperRowFieldSetAccess(
+ rowTerm,
+ indexTerm,
+ fieldTerm,
+ t.asInstanceOf[DistinctType].getSourceType)
+ case _ =>
+ s"$rowTerm.setNonPrimitiveValue($indexTerm, $fieldTerm)"
+ }
// -------------------------- BinaryArray Set Access -------------------------------
+ @tailrec
def binaryArraySetNull(
index: Int,
arrayTerm: String,
- t: LogicalType): String = t.getTypeRoot match {
- case BOOLEAN => s"$arrayTerm.setNullBoolean($index)"
- case TINYINT => s"$arrayTerm.setNullByte($index)"
- case SMALLINT => s"$arrayTerm.setNullShort($index)"
- case INTEGER => s"$arrayTerm.setNullInt($index)"
- case FLOAT => s"$arrayTerm.setNullFloat($index)"
- case DOUBLE => s"$arrayTerm.setNullDouble($index)"
- case TIME_WITHOUT_TIME_ZONE => s"$arrayTerm.setNullInt($index)"
- case DATE => s"$arrayTerm.setNullInt($index)"
- case INTERVAL_YEAR_MONTH => s"$arrayTerm.setNullInt($index)"
- case _ => s"$arrayTerm.setNullLong($index)"
+ t: LogicalType)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition; distinct types are resolved to their source type
+ case BOOLEAN =>
+ s"$arrayTerm.setNullBoolean($index)"
+ case TINYINT =>
+ s"$arrayTerm.setNullByte($index)"
+ case SMALLINT =>
+ s"$arrayTerm.setNullShort($index)"
+ case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH =>
+ s"$arrayTerm.setNullInt($index)"
+ case FLOAT =>
+ s"$arrayTerm.setNullFloat($index)"
+ case DOUBLE =>
+ s"$arrayTerm.setNullDouble($index)"
+ case DISTINCT_TYPE =>
+ binaryArraySetNull(index, arrayTerm, t.asInstanceOf[DistinctType].getSourceType)
+ case _ =>
+ s"$arrayTerm.setNullLong($index)"
}
// -------------------------- BinaryWriter Write -------------------------------
@@ -642,17 +691,22 @@ object CodeGenUtils {
def binaryWriterWriteNull(index: Int, writerTerm: String, t: LogicalType): String =
binaryWriterWriteNull(index.toString, writerTerm, t)
+ @tailrec
def binaryWriterWriteNull(
indexTerm: String,
writerTerm: String,
- t: LogicalType): String = t match {
- case d: DecimalType if !DecimalData.isCompact(d.getPrecision) =>
- s"$writerTerm.writeDecimal($indexTerm, null, ${d.getPrecision})"
- case d: TimestampType if !TimestampData.isCompact(d.getPrecision) =>
- s"$writerTerm.writeTimestamp($indexTerm, null, ${d.getPrecision})"
- case d: LocalZonedTimestampType if !TimestampData.isCompact(d.getPrecision) =>
- s"$writerTerm.writeTimestamp($indexTerm, null, ${d.getPrecision})"
- case _ => s"$writerTerm.setNullAt($indexTerm)"
+ t: LogicalType)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case DECIMAL if !DecimalData.isCompact(getPrecision(t)) =>
+ s"$writerTerm.writeDecimal($indexTerm, null, ${getPrecision(t)})"
+ case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE
+ if !TimestampData.isCompact(getPrecision(t)) =>
+ s"$writerTerm.writeTimestamp($indexTerm, null, ${getPrecision(t)})"
+ case DISTINCT_TYPE =>
+ binaryWriterWriteNull(indexTerm, writerTerm, t.asInstanceOf[DistinctType].getSourceType)
+ case _ =>
+ s"$writerTerm.setNullAt($indexTerm)"
}
def binaryWriterWriteField(
@@ -663,50 +717,74 @@ object CodeGenUtils {
fieldType: LogicalType): String =
binaryWriterWriteField(ctx, index.toString, fieldValTerm, writerTerm, fieldType)
+ @tailrec
def binaryWriterWriteField(
ctx: CodeGeneratorContext,
indexTerm: String,
fieldValTerm: String,
writerTerm: String,
- t: LogicalType): String =
- t.getTypeRoot match {
- case INTEGER => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)"
- case BIGINT => s"$writerTerm.writeLong($indexTerm, $fieldValTerm)"
- case SMALLINT => s"$writerTerm.writeShort($indexTerm, $fieldValTerm)"
- case TINYINT => s"$writerTerm.writeByte($indexTerm, $fieldValTerm)"
- case FLOAT => s"$writerTerm.writeFloat($indexTerm, $fieldValTerm)"
- case DOUBLE => s"$writerTerm.writeDouble($indexTerm, $fieldValTerm)"
- case BOOLEAN => s"$writerTerm.writeBoolean($indexTerm, $fieldValTerm)"
- case VARBINARY | BINARY => s"$writerTerm.writeBinary($indexTerm, $fieldValTerm)"
- case VARCHAR | CHAR => s"$writerTerm.writeString($indexTerm, $fieldValTerm)"
- case DECIMAL =>
- val dt = t.asInstanceOf[DecimalType]
- s"$writerTerm.writeDecimal($indexTerm, $fieldValTerm, ${dt.getPrecision})"
- case DATE => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)"
- case TIME_WITHOUT_TIME_ZONE => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)"
- case TIMESTAMP_WITHOUT_TIME_ZONE =>
- val dt = t.asInstanceOf[TimestampType]
- s"$writerTerm.writeTimestamp($indexTerm, $fieldValTerm, ${dt.getPrecision})"
- case TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
- val dt = t.asInstanceOf[LocalZonedTimestampType]
- s"$writerTerm.writeTimestamp($indexTerm, $fieldValTerm, ${dt.getPrecision})"
- case INTERVAL_YEAR_MONTH => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)"
- case INTERVAL_DAY_TIME => s"$writerTerm.writeLong($indexTerm, $fieldValTerm)"
-
- // complex types
- case ARRAY =>
- val ser = ctx.addReusableTypeSerializer(t)
- s"$writerTerm.writeArray($indexTerm, $fieldValTerm, $ser)"
- case MULTISET | MAP =>
- val ser = ctx.addReusableTypeSerializer(t)
- s"$writerTerm.writeMap($indexTerm, $fieldValTerm, $ser)"
- case ROW =>
- val ser = ctx.addReusableTypeSerializer(t)
- s"$writerTerm.writeRow($indexTerm, $fieldValTerm, $ser)"
- case RAW =>
- val ser = ctx.addReusableTypeSerializer(t)
- s"$writerTerm.writeRawValue($indexTerm, $fieldValTerm, $ser)"
- }
+ t: LogicalType)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case CHAR | VARCHAR =>
+ s"$writerTerm.writeString($indexTerm, $fieldValTerm)"
+ case BOOLEAN =>
+ s"$writerTerm.writeBoolean($indexTerm, $fieldValTerm)"
+ case BINARY | VARBINARY =>
+ s"$writerTerm.writeBinary($indexTerm, $fieldValTerm)"
+ case DECIMAL =>
+ s"$writerTerm.writeDecimal($indexTerm, $fieldValTerm, ${getPrecision(t)})"
+ case TINYINT =>
+ s"$writerTerm.writeByte($indexTerm, $fieldValTerm)"
+ case SMALLINT =>
+ s"$writerTerm.writeShort($indexTerm, $fieldValTerm)"
+ case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH =>
+ s"$writerTerm.writeInt($indexTerm, $fieldValTerm)"
+ case BIGINT | INTERVAL_DAY_TIME =>
+ s"$writerTerm.writeLong($indexTerm, $fieldValTerm)"
+ case FLOAT =>
+ s"$writerTerm.writeFloat($indexTerm, $fieldValTerm)"
+ case DOUBLE =>
+ s"$writerTerm.writeDouble($indexTerm, $fieldValTerm)"
+ case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
+ s"$writerTerm.writeTimestamp($indexTerm, $fieldValTerm, ${getPrecision(t)})"
+ case TIMESTAMP_WITH_TIME_ZONE =>
+ throw new UnsupportedOperationException("Unsupported type: " + t)
+ case ARRAY =>
+ val ser = ctx.addReusableTypeSerializer(t)
+ s"$writerTerm.writeArray($indexTerm, $fieldValTerm, $ser)"
+ case MULTISET | MAP =>
+ val ser = ctx.addReusableTypeSerializer(t)
+ s"$writerTerm.writeMap($indexTerm, $fieldValTerm, $ser)"
+ case ROW | STRUCTURED_TYPE =>
+ val ser = ctx.addReusableTypeSerializer(t)
+ s"$writerTerm.writeRow($indexTerm, $fieldValTerm, $ser)"
+ case DISTINCT_TYPE =>
+ binaryWriterWriteField(
+ ctx,
+ indexTerm,
+ fieldValTerm,
+ writerTerm,
+ t.asInstanceOf[DistinctType].getSourceType)
+ case RAW =>
+ val ser = ctx.addReusableTypeSerializer(t)
+ s"$writerTerm.writeRawValue($indexTerm, $fieldValTerm, $ser)"
+ case NULL | SYMBOL | UNRESOLVED =>
+ throw new IllegalArgumentException("Illegal type: " + t);
+ }
+
+ // -------------------------- Data Structure Conversion -------------------------------
+
+ /**
+ * If the class is internally compatible, no DataStructure converter is needed.
+ * clazz != classOf[Row] => Row can only infer GenericType[Row].
+ */
+ def isInternalClass(t: DataType): Boolean = {
+ val clazz = t.getConversionClass
+ clazz != classOf[Object] && clazz != classOf[Row] &&
+ (classOf[RowData].isAssignableFrom(clazz) ||
+ clazz == toInternalConversionClass(fromDataTypeToLogicalType(t)))
+ }
private def isConverterIdentity(t: DataType): Boolean = {
DataFormatConverters.getConverterForDataType(t).isInstanceOf[IdentityConverter[_]]
@@ -808,9 +886,4 @@ object CodeGenUtils {
s"${internalExpr.nullTerm} ? null : ($externalResultTerm)"
}
}
-
- def udfFieldName(udf: UserDefinedFunction): String = s"function_${udf.functionIdentifier}"
-
- def genLogInfo(logTerm: String, format: String, argTerm: String): String =
- s"""$logTerm.info("$format", $argTerm);"""
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
index 174158d..850d55f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
@@ -24,8 +24,10 @@ import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.generateE
import org.apache.flink.table.runtime.generated.{GeneratedRecordEqualiser, RecordEqualiser}
import org.apache.flink.table.runtime.types.PlannerTypeUtils
import org.apache.flink.table.types.logical.LogicalTypeRoot._
-import org.apache.flink.table.types.logical.{LogicalType, RowType}
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldTypes, isCompositeType}
+import org.apache.flink.table.types.logical.{DistinctType, LogicalType}
+import scala.annotation.tailrec
import scala.collection.JavaConverters._
class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) {
@@ -57,9 +59,9 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) {
// TODO merge ScalarOperatorGens.generateEquals.
val (equalsCode, equalsResult) = if (isInternalPrimitive(fieldType)) {
("", s"$leftFieldTerm == $rightFieldTerm")
- } else if (isRowData(fieldType)) {
+ } else if (isCompositeType(fieldType)) {
val equaliserGenerator = new EqualiserCodeGenerator(
- fieldType.asInstanceOf[RowType].getChildren.asScala.toArray)
+ getFieldTypes(fieldType).asScala.toArray)
val generatedEqualiser = equaliserGenerator
.generateRecordEqualiser("field$" + i + "GeneratedEqualiser")
val generatedEqualiserTerm = ctx.addReusableObject(
@@ -128,15 +130,14 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) {
new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray)
}
+ @tailrec
private def isInternalPrimitive(t: LogicalType): Boolean = t.getTypeRoot match {
case _ if PlannerTypeUtils.isPrimitive(t) => true
- case DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH |INTERVAL_DAY_TIME => true
- case _ => false
- }
+ case DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH | INTERVAL_DAY_TIME => true
+
+ case DISTINCT_TYPE => isInternalPrimitive(t.asInstanceOf[DistinctType].getSourceType)
- private def isRowData(t: LogicalType): Boolean = t match {
- case _: RowType => true
case _ => false
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala
index 950a35b..82a0122 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala
@@ -30,11 +30,9 @@ import org.apache.flink.table.planner.codegen.FunctionCodeGenerator.generateFunc
import org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall
import org.apache.flink.table.types.logical.RowType
import org.apache.flink.table.util.TimestampStringUtils.fromLocalDateTime
-
import org.apache.calcite.avatica.util.ByteString
import org.apache.calcite.rex.{RexBuilder, RexExecutor, RexNode}
import org.apache.calcite.sql.`type`.SqlTypeName
-
import java.io.File
import scala.collection.JavaConverters._
@@ -72,7 +70,9 @@ class ExpressionReducer(
// we don't support object literals yet, we skip those constant expressions
case (SqlTypeName.ANY, _) |
+ (SqlTypeName.OTHER, _) |
(SqlTypeName.ROW, _) |
+ (SqlTypeName.STRUCTURED, _) |
(SqlTypeName.ARRAY, _) |
(SqlTypeName.MAP, _) |
(SqlTypeName.MULTISET, _) => None
@@ -133,7 +133,9 @@ class ExpressionReducer(
unreduced.getType.getSqlTypeName match {
// we insert the original expression for object literals
case SqlTypeName.ANY |
+ SqlTypeName.OTHER |
SqlTypeName.ROW |
+ SqlTypeName.STRUCTURED |
SqlTypeName.ARRAY |
SqlTypeName.MAP |
SqlTypeName.MULTISET =>
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
index 4fc52d2..9d0fe44 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala
@@ -34,12 +34,13 @@ import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE}
import org.apache.flink.table.planner.codegen.calls.CurrentTimePointCallGen
import org.apache.flink.table.planner.plan.utils.SortUtil
-import org.apache.flink.table.runtime.types.PlannerTypeUtils
import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.{isCharacterString, isReference, isTemporal}
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical._
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldCount, getFieldTypes}
import org.apache.flink.table.util.TimestampStringUtils.toLocalDateTime
+import scala.annotation.tailrec
import scala.collection.mutable
/**
@@ -209,39 +210,47 @@ object GenerateUtils {
/**
* Generates a record declaration statement. The record can be any type of RowData or
* other types.
+ *
* @param t the record type
* @param clazz the specified class of the type (only used when RowType)
* @param recordTerm the record term to be declared
* @param recordWriterTerm the record writer term (only used when BinaryRowData type)
* @return the record declaration statement
- */
+ */
+ @tailrec
def generateRecordStatement(
t: LogicalType,
clazz: Class[_],
recordTerm: String,
- recordWriterTerm: Option[String] = None): String = {
- t match {
- case rt: RowType if clazz == classOf[BinaryRowData] =>
- val writerTerm = recordWriterTerm.getOrElse(
- throw new CodeGenException("No writer is specified when writing BinaryRowData record.")
- )
- val binaryRowWriter = className[BinaryRowWriter]
- val typeTerm = clazz.getCanonicalName
- s"""
- |final $typeTerm $recordTerm = new $typeTerm(${rt.getFieldCount});
- |final $binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm);
- |""".stripMargin.trim
- case rt: RowType if clazz == classOf[GenericRowData] ||
- clazz == classOf[BoxedWrapperRowData] =>
- val typeTerm = clazz.getCanonicalName
- s"final $typeTerm $recordTerm = new $typeTerm(${rt.getFieldCount});"
- case _: RowType if clazz == classOf[JoinedRowData] =>
- val typeTerm = clazz.getCanonicalName
- s"final $typeTerm $recordTerm = new $typeTerm();"
- case _ =>
- val typeTerm = boxedTypeTermForType(t)
- s"final $typeTerm $recordTerm = new $typeTerm();"
- }
+ recordWriterTerm: Option[String] = None)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case ROW | STRUCTURED_TYPE if clazz == classOf[BinaryRowData] =>
+ val writerTerm = recordWriterTerm.getOrElse(
+ throw new CodeGenException("No writer is specified when writing BinaryRowData record.")
+ )
+ val binaryRowWriter = className[BinaryRowWriter]
+ val typeTerm = clazz.getCanonicalName
+ s"""
+ |final $typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});
+ |final $binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm);
+ |""".stripMargin.trim
+ case ROW | STRUCTURED_TYPE if clazz == classOf[GenericRowData] ||
+ clazz == classOf[BoxedWrapperRowData] =>
+ val typeTerm = clazz.getCanonicalName
+ s"final $typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});"
+ case ROW | STRUCTURED_TYPE if clazz == classOf[JoinedRowData] =>
+ val typeTerm = clazz.getCanonicalName
+ s"final $typeTerm $recordTerm = new $typeTerm();"
+ case DISTINCT_TYPE =>
+ generateRecordStatement(
+ t.asInstanceOf[DistinctType].getSourceType,
+ clazz,
+ recordTerm,
+ recordWriterTerm)
+ case _ =>
+ val typeTerm = boxedTypeTermForType(t)
+ s"final $typeTerm $recordTerm = new $typeTerm();"
}
def generateNullLiteral(
@@ -273,6 +282,7 @@ object GenerateUtils {
literalValue = Some(literalValue))
}
+ @tailrec
def generateLiteral(
ctx: CodeGeneratorContext,
literalType: LogicalType,
@@ -282,10 +292,41 @@ object GenerateUtils {
}
// non-null values
literalType.getTypeRoot match {
+ // ordered by type root definition
+ case CHAR | VARCHAR =>
+ val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(literalValue.toString)
+ val field = ctx.addReusableStringConstants(escapedValue)
+ generateNonNullLiteral(literalType, field, StringData.fromString(escapedValue))
case BOOLEAN =>
generateNonNullLiteral(literalType, literalValue.toString, literalValue)
+ case BINARY | VARBINARY =>
+ val bytesVal = literalValue.asInstanceOf[ByteString].getBytes
+ val fieldTerm = ctx.addReusableObject(
+ bytesVal, "binary", bytesVal.getClass.getCanonicalName)
+ generateNonNullLiteral(literalType, fieldTerm, bytesVal)
+
+ case DECIMAL =>
+ val dt = literalType.asInstanceOf[DecimalType]
+ val precision = dt.getPrecision
+ val scale = dt.getScale
+ val fieldTerm = newName("decimal")
+ val decimalClass = className[DecimalData]
+ val fieldDecimal =
+ s"""
+ |$decimalClass $fieldTerm =
+ | $DECIMAL_UTIL.castFrom("${literalValue.toString}", $precision, $scale);
+ |""".stripMargin
+ ctx.addReusableMember(fieldDecimal)
+ val value = DecimalData.fromBigDecimal(
+ literalValue.asInstanceOf[JBigDecimal], precision, scale)
+ if (value == null) {
+ generateNullLiteral(literalType, ctx.nullCheck)
+ } else {
+ generateNonNullLiteral(literalType, fieldTerm, value)
+ }
+
case TINYINT =>
val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
generateNonNullLiteral(literalType, decimal.byteValue().toString, decimal.byteValue())
@@ -335,36 +376,6 @@ object GenerateUtils {
case _ => generateNonNullLiteral(
literalType, doubleValue.toString + "d", doubleValue)
}
- case DECIMAL =>
- val dt = literalType.asInstanceOf[DecimalType]
- val precision = dt.getPrecision
- val scale = dt.getScale
- val fieldTerm = newName("decimal")
- val decimalClass = className[DecimalData]
- val fieldDecimal =
- s"""
- |$decimalClass $fieldTerm =
- | $DECIMAL_UTIL.castFrom("${literalValue.toString}", $precision, $scale);
- |""".stripMargin
- ctx.addReusableMember(fieldDecimal)
- val value = DecimalData.fromBigDecimal(
- literalValue.asInstanceOf[JBigDecimal], precision, scale)
- if (value == null) {
- generateNullLiteral(literalType, ctx.nullCheck)
- } else {
- generateNonNullLiteral(literalType, fieldTerm, value)
- }
-
- case VARCHAR | CHAR =>
- val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(literalValue.toString)
- val field = ctx.addReusableStringConstants(escapedValue)
- generateNonNullLiteral(literalType, field, StringData.fromString(escapedValue))
-
- case VARBINARY | BINARY =>
- val bytesVal = literalValue.asInstanceOf[ByteString].getBytes
- val fieldTerm = ctx.addReusableObject(
- bytesVal, "binary", bytesVal.getClass.getCanonicalName)
- generateNonNullLiteral(literalType, fieldTerm, bytesVal)
case DATE =>
generateNonNullLiteral(literalType, literalValue.toString, literalValue)
@@ -384,6 +395,9 @@ object GenerateUtils {
ctx.addReusableMember(fieldTimestamp)
generateNonNullLiteral(literalType, fieldTerm, ts)
+ case TIMESTAMP_WITH_TIME_ZONE =>
+ throw new UnsupportedOperationException("Unsupported type: " + literalType)
+
case TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
val fieldTerm = newName("timestampWithLocalZone")
val ins =
@@ -420,13 +434,19 @@ object GenerateUtils {
s"Decimal '$decimal' can not be converted to interval of milliseconds.")
}
+ case DISTINCT_TYPE =>
+ generateLiteral(ctx, literalType.asInstanceOf[DistinctType].getSourceType, literalValue)
+
// Symbol type for special flags e.g. TRIM's BOTH, LEADING, TRAILING
case RAW if literalType.asInstanceOf[TypeInformationRawType[_]]
.getTypeInformation.getTypeClass.isAssignableFrom(classOf[Enum[_]]) =>
generateSymbol(literalValue.asInstanceOf[Enum[_]])
- case t@_ =>
- throw new CodeGenException(s"Type not supported: $t")
+ case SYMBOL =>
+ throw new UnsupportedOperationException() // TODO support symbol?
+
+ case ARRAY | MULTISET | MAP | ROW | STRUCTURED_TYPE | NULL | UNRESOLVED =>
+ throw new CodeGenException(s"Type not supported: $literalType")
}
}
@@ -546,10 +566,15 @@ object GenerateUtils {
index: Int,
deepCopy: Boolean = false): GeneratedExpression = {
- val fieldType = inputType match {
- case ct: RowType => ct.getTypeAt(index)
- case _ => inputType
+ @tailrec
+ def getFieldType(t: LogicalType, pos: Int): LogicalType = t.getTypeRoot match {
+ // ordered by type root definition
+ case ROW | STRUCTURED_TYPE => t.getChildren.get(pos)
+ case DISTINCT_TYPE => getFieldType(t.asInstanceOf[DistinctType].getSourceType, pos)
+ case _ => t
}
+
+ val fieldType = getFieldType(inputType, index)
val resultTypeTerm = primitiveTypeTermForType(fieldType)
val defaultValue = primitiveDefaultValue(fieldType)
val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
@@ -636,14 +661,16 @@ object GenerateUtils {
}
}
+ @tailrec
def generateFieldAccess(
ctx: CodeGeneratorContext,
inputType: LogicalType,
inputTerm: String,
- index: Int): GeneratedExpression =
- inputType match {
- case ct: RowType =>
- val fieldType = ct.getTypeAt(index)
+ index: Int)
+ : GeneratedExpression = inputType.getTypeRoot match {
+ // ordered by type root definition
+ case ROW | STRUCTURED_TYPE =>
+ val fieldType = getFieldTypes(inputType).get(index)
val resultTypeTerm = primitiveTypeTermForType(fieldType)
val defaultValue = primitiveDefaultValue(fieldType)
val readCode = rowFieldReadAccess(ctx, index.toString, inputTerm, fieldType)
@@ -667,6 +694,13 @@ object GenerateUtils {
}
GeneratedExpression(fieldTerm, nullTerm, inputCode, fieldType)
+ case DISTINCT_TYPE =>
+ generateFieldAccess(
+ ctx,
+ inputType.asInstanceOf[DistinctType].getSourceType,
+ inputTerm,
+ index)
+
case _ =>
val fieldTypeTerm = boxedTypeTermForType(inputType)
val inputCode = s"($fieldTypeTerm) $inputTerm"
@@ -674,23 +708,30 @@ object GenerateUtils {
}
/**
- * Generates code for comparing two field.
+ * Generates code for comparing two fields.
*/
+ @tailrec
def generateCompare(
ctx: CodeGeneratorContext,
t: LogicalType,
nullsIsLast: Boolean,
leftTerm: String,
- rightTerm: String): String = t.getTypeRoot match {
- case BOOLEAN => s"($leftTerm == $rightTerm ? 0 : ($leftTerm ? 1 : -1))"
- case DATE | TIME_WITHOUT_TIME_ZONE =>
- s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)"
- case _ if PlannerTypeUtils.isPrimitive(t) =>
- s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)"
- case VARBINARY | BINARY =>
+ rightTerm: String)
+ : String = t.getTypeRoot match {
+ // ordered by type root definition
+ case CHAR | VARCHAR | DECIMAL | TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
+ s"$leftTerm.compareTo($rightTerm)"
+ case BOOLEAN =>
+ s"($leftTerm == $rightTerm ? 0 : ($leftTerm ? 1 : -1))"
+ case BINARY | VARBINARY =>
val sortUtil = classOf[org.apache.flink.table.runtime.operators.sort.SortUtil]
.getCanonicalName
s"$sortUtil.compareBinary($leftTerm, $rightTerm)"
+ case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | DATE | TIME_WITHOUT_TIME_ZONE |
+ INTERVAL_YEAR_MONTH | INTERVAL_DAY_TIME =>
+ s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)"
+ case TIMESTAMP_WITH_TIME_ZONE | MULTISET | MAP =>
+ throw new UnsupportedOperationException() // TODO support MULTISET and MAP?
case ARRAY =>
val at = t.asInstanceOf[ArrayType]
val compareFunc = newName("compareArray")
@@ -706,13 +747,13 @@ object GenerateUtils {
"""
ctx.addReusableMember(funcCode)
s"$compareFunc($leftTerm, $rightTerm)"
- case ROW =>
- val rowType = t.asInstanceOf[RowType]
- val orders = (0 until rowType.getFieldCount).map(_ => true).toArray
+ case ROW | STRUCTURED_TYPE =>
+ val fieldCount = getFieldCount(t)
+ val orders = (0 until fieldCount).map(_ => true).toArray
val comparisons = generateRowCompare(
ctx,
- (0 until rowType.getFieldCount).toArray,
- rowType.getChildren.toArray(Array[LogicalType]()),
+ (0 until fieldCount).toArray,
+ getFieldTypes(t).toArray(Array[LogicalType]()),
orders,
SortUtil.getNullDefaultOrders(orders),
"a",
@@ -727,18 +768,38 @@ object GenerateUtils {
"""
ctx.addReusableMember(funcCode)
s"$compareFunc($leftTerm, $rightTerm)"
+ case DISTINCT_TYPE =>
+ generateCompare(
+ ctx,
+ t.asInstanceOf[DistinctType].getSourceType,
+ nullsIsLast,
+ leftTerm,
+ rightTerm)
case RAW =>
- val rawType = t.asInstanceOf[TypeInformationRawType[_]]
- val ser = ctx.addReusableObject(
- rawType.getTypeInformation.createSerializer(new ExecutionConfig), "serializer")
- val comp = ctx.addReusableObject(
- rawType.getTypeInformation.asInstanceOf[AtomicTypeInfo[_]]
- .createComparator(true, new ExecutionConfig),
- "comparator")
- s"""
- |$comp.compare($leftTerm.toObject($ser), $rightTerm.toObject($ser))
- """.stripMargin
- case other => s"$leftTerm.compareTo($rightTerm)"
+ t match {
+ case rawType: RawType[_] =>
+ val clazz = rawType.getOriginatingClass
+ if (!classOf[Comparable[_]].isAssignableFrom(clazz)) {
+ throw new CodeGenException(
+ s"Raw type class '$clazz' must implement ${className[Comparable[_]]} to be used " +
+ s"in a comparision of two '${rawType.asSummaryString()}' types.")
+ }
+ val serializer = rawType.getTypeSerializer
+ val serializerTerm = ctx.addReusableObject(serializer, "serializer")
+ s"((${className[Comparable[_]]}) $leftTerm.toObject($serializerTerm))" +
+ s".compareTo($rightTerm.toObject($serializerTerm))"
+
+ case rawType: TypeInformationRawType[_] =>
+ val serializer = rawType.getTypeInformation.createSerializer(new ExecutionConfig)
+ val ser = ctx.addReusableObject(serializer, "serializer")
+ val comp = ctx.addReusableObject(
+ rawType.getTypeInformation.asInstanceOf[AtomicTypeInfo[_]]
+ .createComparator(true, new ExecutionConfig),
+ "comparator")
+ s"$comp.compare($leftTerm.toObject($ser), $rightTerm.toObject($ser))"
+ }
+ case NULL | SYMBOL | UNRESOLVED =>
+ throw new IllegalArgumentException("Illegal type: " + t)
}
/**
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
index 7493aa2..1fbacc0 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
@@ -18,7 +18,6 @@
package org.apache.flink.table.planner.codegen.agg.batch
-import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.runtime.util.SingleElementIterator
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.table.data.{GenericRowData, RowData}
@@ -39,12 +38,13 @@ import org.apache.flink.table.runtime.types.InternalSerializers
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.{fromDataTypeToLogicalType, fromLogicalTypeToDataType}
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.LogicalTypeRoot._
-import org.apache.flink.table.types.logical.{LogicalType, RowType}
-
+import org.apache.flink.table.types.logical.{DistinctType, LogicalType, RowType}
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rex.RexNode
import org.apache.calcite.tools.RelBuilder
+import scala.annotation.tailrec
+
/**
* Batch aggregate code generate helper.
*/
@@ -360,16 +360,7 @@ object AggCodeGenHelper {
aggBufferExprs.zip(initAggBufferExprs).map {
case (aggBufVar, initExpr) =>
- val resultCode = aggBufVar.resultType.getTypeRoot match {
- case VARCHAR | CHAR | ROW | ARRAY | MULTISET | MAP =>
- val serializer = InternalSerializers.create(
- aggBufVar.resultType, new ExecutionConfig)
- val term = ctx.addReusableObject(
- serializer, "serializer", serializer.getClass.getCanonicalName)
- val typeTerm = boxedTypeTermForType(aggBufVar.resultType)
- s"($typeTerm) $term.copy(${initExpr.resultTerm})"
- case _ => initExpr.resultTerm
- }
+ val resultCode = genElementCopyTerm(ctx, aggBufVar.resultType, initExpr.resultTerm)
s"""
|${initExpr.code}
|${aggBufVar.nullTerm} = ${initExpr.nullTerm};
@@ -378,6 +369,23 @@ object AggCodeGenHelper {
} mkString "\n"
}
+ @tailrec
+ private def genElementCopyTerm(
+ ctx: CodeGeneratorContext,
+ t: LogicalType,
+ inputTerm: String)
+ : String = t.getTypeRoot match {
+ case CHAR | VARCHAR | ARRAY | MULTISET | MAP | ROW | STRUCTURED_TYPE =>
+ val serializer = InternalSerializers.create(t)
+ val term = ctx.addReusableObject(
+ serializer, "serializer", serializer.getClass.getCanonicalName)
+ val typeTerm = boxedTypeTermForType(t)
+ s"($typeTerm) $term.copy($inputTerm)"
+ case DISTINCT_TYPE =>
+ genElementCopyTerm(ctx, t.asInstanceOf[DistinctType].getSourceType, inputTerm)
+ case _ => inputTerm
+ }
+
private[flink] def genAggregateByFlatAggregateBuffer(
isMerge: Boolean,
ctx: CodeGeneratorContext,
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java
index 319021d..9a58882 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java
@@ -18,6 +18,7 @@
package org.apache.flink.table.runtime.typeutils;
+import org.apache.flink.table.types.logical.DistinctType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeFamily;
import org.apache.flink.table.types.logical.TimestampKind;
@@ -60,9 +61,10 @@ public class TypeCheckUtils {
}
public static boolean isTimeInterval(LogicalType type) {
+ // ordered by type root definition
switch (type.getTypeRoot()) {
- case INTERVAL_DAY_TIME:
case INTERVAL_YEAR_MONTH:
+ case INTERVAL_DAY_TIME:
return true;
default:
return false;
@@ -122,22 +124,28 @@ public class TypeCheckUtils {
}
public static boolean isMutable(LogicalType type) {
- // the internal representation of String is StringData which is mutable
+ // ordered by type root definition
switch (type.getTypeRoot()) {
- case VARCHAR:
case CHAR:
+ case VARCHAR: // the internal representation of String is StringData which is mutable
case ARRAY:
case MULTISET:
case MAP:
case ROW:
+ case STRUCTURED_TYPE:
case RAW:
return true;
+ case TIMESTAMP_WITH_TIME_ZONE:
+ throw new UnsupportedOperationException("Unsupported type: " + type);
+ case DISTINCT_TYPE:
+ return isMutable(((DistinctType) type).getSourceType());
default:
return false;
}
}
public static boolean isReference(LogicalType type) {
+ // ordered by type root definition
switch (type.getTypeRoot()) {
case BOOLEAN:
case TINYINT:
@@ -153,6 +161,10 @@ public class TypeCheckUtils {
case INTERVAL_YEAR_MONTH:
case INTERVAL_DAY_TIME:
return false;
+ case TIMESTAMP_WITH_TIME_ZONE:
+ throw new UnsupportedOperationException("Unsupported type: " + type);
+ case DISTINCT_TYPE:
+ return isReference(((DistinctType) type).getSourceType());
default:
return true;
}