You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2016/12/07 18:00:19 UTC
flink git commit: [FLINK-4554] [table] Add support for array types
Repository: flink
Updated Branches:
refs/heads/master 13150a4ba -> 441400855
[FLINK-4554] [table] Add support for array types
This closes #2919.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/44140085
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/44140085
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/44140085
Branch: refs/heads/master
Commit: 4414008551b2843d98b3caddbb171fa1934e1f40
Parents: 13150a4
Author: twalthr <tw...@apache.org>
Authored: Fri Sep 23 16:44:42 2016 +0200
Committer: twalthr <tw...@apache.org>
Committed: Wed Dec 7 18:57:29 2016 +0100
----------------------------------------------------------------------
docs/dev/table_api.md | 158 +++++++-
.../flink/api/scala/table/expressionDsl.scala | 50 ++-
.../flink/api/table/FlinkTypeFactory.scala | 22 +-
.../flink/api/table/codegen/CodeGenUtils.scala | 11 +-
.../flink/api/table/codegen/CodeGenerator.scala | 37 ++
.../api/table/codegen/ExpressionReducer.scala | 4 +-
.../table/codegen/calls/ScalarOperators.scala | 198 +++++++++-
.../table/expressions/ExpressionParser.scala | 9 +-
.../api/table/expressions/ExpressionUtils.scala | 61 +++-
.../flink/api/table/expressions/array.scala | 146 ++++++++
.../api/table/expressions/comparison.scala | 3 -
.../api/table/plan/ProjectionTranslator.scala | 8 +-
.../table/plan/schema/ArrayRelDataType.scala | 53 +++
.../api/table/typeutils/TypeCheckUtils.scala | 14 +-
.../api/table/validate/FunctionCatalog.scala | 12 +-
.../src/test/resources/log4j-test.properties | 2 +-
.../api/table/expressions/ArrayTypeTest.scala | 359 +++++++++++++++++++
.../table/expressions/SqlExpressionTest.scala | 14 +-
18 files changed, 1122 insertions(+), 39 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/docs/dev/table_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 6cf0dee..2b42ab2 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -1470,7 +1470,14 @@ The Table API is built on top of Flink's DataSet and DataStream API. Internally,
| `Types.INTERVAL_MONTHS`| `INTERVAL YEAR TO MONTH` | `java.lang.Integer` |
| `Types.INTERVAL_MILLIS`| `INTERVAL DAY TO SECOND(3)` | `java.lang.Long` |
-Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and arrays can be fields of a row. Generic types and arrays are treated as a black box within Table API and SQL yet. Composite types, however, are fully supported types where fields of a composite type can be accessed using the `.get()` operator in Table API and dot operator (e.g. `MyTable.pojoColumn.myField`) in SQL. Composite types can also be flattened using `.flatten()` in Table API or `MyTable.pojoColumn.*` in SQL.
+
+Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and array types (object or primitive arrays) can be fields of a row.
+
+Generic types are currently treated as a black box within Table API and SQL.
+
+Composite types, however, are fully supported types where fields of a composite type can be accessed using the `.get()` operator in Table API and dot operator (e.g. `MyTable.pojoColumn.myField`) in SQL. Composite types can also be flattened using `.flatten()` in Table API or `MyTable.pojoColumn.*` in SQL.
+
+Elements of an array can be accessed using the `myArray.at(1)` operator in Table API and the `myArray[1]` operator in SQL (the index starts at 1). Array literals can be created using `array(1, 2, 3)` in Table API and `ARRAY[1, 2, 3]` in SQL.
{% top %}
@@ -2038,6 +2045,50 @@ COMPOSITE.get(INT)
</td>
</tr>
+ <tr>
+ <td>
+ {% highlight java %}
+ARRAY.at(INT)
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the element at a particular position in an array. The index starts at 1.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight java %}
+array(ANY [, ANY ]*)
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Creates an array from a list of values. The array will be an array of objects (not primitives).</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight java %}
+ARRAY.cardinality()
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the number of elements of an array.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight java %}
+ARRAY.element()
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the sole element of an array with a single element. Returns <code>null</code> if the array is empty. Throws an exception if the array has more than one element.</p>
+ </td>
+ </tr>
+
</tbody>
</table>
@@ -2599,6 +2650,50 @@ COMPOSITE.get(INT)
</td>
</tr>
+ <tr>
+ <td>
+ {% highlight scala %}
+ARRAY.at(INT)
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the element at a particular position in an array. The index starts at 1.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight scala %}
+array(ANY [, ANY ]*)
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Creates an array from a list of values. The array will be an array of objects (not primitives).</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight scala %}
+ARRAY.cardinality()
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the number of elements of an array.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight scala %}
+ARRAY.element()
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the sole element of an array with a single element. Returns <code>null</code> if the array is empty. Throws an exception if the array has more than one element.</p>
+ </td>
+ </tr>
+
</tbody>
</table>
</div>
@@ -3368,8 +3463,6 @@ CAST(value AS type)
</tbody>
</table>
-
-<!-- Disabled temporarily in favor of composite type support
<table class="table table-bordered">
<thead>
<tr>
@@ -3379,6 +3472,7 @@ CAST(value AS type)
</thead>
<tbody>
+ <!-- Disabled temporarily in favor of composite type support
<tr>
<td>
{% highlight text %}
@@ -3400,9 +3494,32 @@ ROW (value [, value]* )
<p>Creates a row from a list of values.</p>
</td>
</tr>
+-->
+
+ <tr>
+ <td>
+ {% highlight text %}
+array ‘[’ index ‘]’
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the element at a particular position in an array. The index starts at 1.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight text %}
+ARRAY ‘[’ value [, value ]* ‘]’
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Creates an array from a list of values.</p>
+ </td>
+ </tr>
+
</tbody>
</table>
--->
<table class="table table-bordered">
<thead>
@@ -3657,6 +3774,39 @@ tableName.compositeType.*
</tbody>
</table>
+<table class="table table-bordered">
+ <thead>
+ <tr>
+ <th class="text-left" style="width: 40%">Array functions</th>
+ <th class="text-center">Description</th>
+ </tr>
+ </thead>
+
+ <tbody>
+ <tr>
+ <td>
+ {% highlight text %}
+CARDINALITY(ARRAY)
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the number of elements of an array.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight text %}
+ELEMENT(ARRAY)
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns the sole element of an array with a single element. Returns <code>null</code> if the array is empty. Throws an exception if the array has more than one element.</p>
+ </td>
+ </tr>
+ </tbody>
+</table>
+
</div>
</div>
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
index 175ce2e..823458a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
@@ -21,9 +21,10 @@ import java.sql.{Date, Time, Timestamp}
import org.apache.calcite.avatica.util.DateTimeUtils._
import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
-import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval, toMonthInterval, toRowInterval}
+import org.apache.flink.api.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval}
import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit
import org.apache.flink.api.table.expressions._
+import java.math.{BigDecimal => JBigDecimal}
import scala.language.implicitConversions
@@ -461,6 +462,29 @@ trait ImplicitExpressionOperations {
* into a flat representation where every subtype is a separate field.
*/
def flatten() = Flattening(expr)
+
+  /**
+    * Looks up a single element of an array by its position. Positions are counted from 1.
+    *
+    * @param index expression for the position of the requested element (first element is 1)
+    * @return expression for the element at that position
+    */
+  def at(index: Expression) = {
+    ArrayElementAt(expr, index)
+  }
+
+  /**
+    * Counts the elements of an array.
+    *
+    * @return expression for the number of elements
+    */
+  def cardinality() = {
+    ArrayCardinality(expr)
+  }
+
+  /**
+    * Unwraps an array that holds exactly one element. An empty array yields null; an array
+    * with more than one element causes an exception.
+    *
+    * @return expression for the only element of a single-element array
+    */
+  def element() = {
+    ArrayElement(expr)
+  }
}
/**
@@ -540,18 +564,24 @@ trait ImplicitExpressionConversions {
implicit def float2Literal(d: Float): Expression = Literal(d)
implicit def string2Literal(str: String): Expression = Literal(str)
implicit def boolean2Literal(bool: Boolean): Expression = Literal(bool)
- implicit def javaDec2Literal(javaDec: java.math.BigDecimal): Expression = Literal(javaDec)
- implicit def scalaDec2Literal(scalaDec: scala.math.BigDecimal): Expression =
+ implicit def javaDec2Literal(javaDec: JBigDecimal): Expression = Literal(javaDec)
+ implicit def scalaDec2Literal(scalaDec: BigDecimal): Expression =
Literal(scalaDec.bigDecimal)
implicit def sqlDate2Literal(sqlDate: Date): Expression = Literal(sqlDate)
implicit def sqlTime2Literal(sqlTime: Time): Expression = Literal(sqlTime)
- implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = Literal(sqlTimestamp)
+ implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression =
+ Literal(sqlTimestamp)
+ implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array)
}
// ------------------------------------------------------------------------------------------------
// Expressions with no parameters
// ------------------------------------------------------------------------------------------------
+// we disable the object checker here as it checks for capital letters of objects
+// but we want that objects look like functions in certain cases e.g. array(1, 2, 3)
+// scalastyle:off object.name
+
/**
* Returns the current SQL date in UTC time zone.
*/
@@ -645,5 +675,17 @@ object temporalOverlaps {
}
}
+/**
+  * Creates an array from a list of values. The array will be an array of objects
+  * (not primitives).
+  */
+object array {
+
+  /**
+    * Creates an array from a list of values. The array will be an array of objects
+    * (not primitives).
+    *
+    * @param head first element of the array; mandatory, so at least one element is given
+    * @param tail remaining elements of the array, in order
+    * @return expression that constructs the array
+    */
+  def apply(head: Expression, tail: Expression*): Expression = {
+    // varargs already arrive as a Seq, so no conversion is needed before prepending
+    ArrayConstructor(head +: tail)
+  }
+}
+// scalastyle:on object.name
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
index bb11576..8dcd660 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
@@ -26,11 +26,12 @@ import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{NothingTypeInfo, SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.{NothingTypeInfo, PrimitiveArrayTypeInfo, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
import org.apache.flink.api.java.typeutils.ValueTypeInfo._
import org.apache.flink.api.table.FlinkTypeFactory.typeInfoToSqlTypeName
-import org.apache.flink.api.table.plan.schema.{CompositeRelDataType, GenericRelDataType}
+import org.apache.flink.api.table.plan.schema.{ArrayRelDataType, CompositeRelDataType, GenericRelDataType}
import org.apache.flink.api.table.typeutils.TimeIntervalTypeInfo
import org.apache.flink.api.table.typeutils.TypeCheckUtils.isSimple
@@ -102,11 +103,22 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
}
}
+ // Array types created through the type factory are backed by an object array type info
+ // derived from the element type; the maxCardinality argument is not used here.
+ override def createArrayType(elementType: RelDataType, maxCardinality: Long): RelDataType =
+ new ArrayRelDataType(
+ ObjectArrayTypeInfo.getInfoFor(FlinkTypeFactory.toTypeInfo(elementType)),
+ elementType,
+ true)
+
private def createAdvancedType(typeInfo: TypeInformation[_]): RelDataType = typeInfo match {
case ct: CompositeType[_] =>
new CompositeRelDataType(ct, this)
- // TODO add specific RelDataTypes for PrimitiveArrayTypeInfo, ObjectArrayTypeInfo
+ case pa: PrimitiveArrayTypeInfo[_] =>
+ new ArrayRelDataType(pa, createTypeFromTypeInfo(pa.getComponentType), false)
+
+ case oa: ObjectArrayTypeInfo[_, _] =>
+ new ArrayRelDataType(oa, createTypeFromTypeInfo(oa.getComponentInfo), true)
+
case ti: TypeInformation[_] =>
new GenericRelDataType(typeInfo, getTypeSystem.asInstanceOf[FlinkTypeSystem])
@@ -190,6 +202,10 @@ object FlinkTypeFactory {
// ROW and CURSOR for UDTF case, whose type info will never be used, just a placeholder
case ROW | CURSOR => new NothingTypeInfo
+ case ARRAY if relDataType.isInstanceOf[ArrayRelDataType] =>
+ val arrayRelDataType = relDataType.asInstanceOf[ArrayRelDataType]
+ arrayRelDataType.typeInfo
+
case _@t =>
throw TableException(s"Type is not supported: $t")
}
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
index b78012c..4092a24 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
@@ -155,7 +155,6 @@ object CodeGenUtils {
def enumValueOf[T <: Enum[T]](cls: Class[_], stringValue: String): Enum[_] =
Enum.valueOf(cls.asInstanceOf[Class[T]], stringValue).asInstanceOf[Enum[_]]
-
// ----------------------------------------------------------------------------------------------
def requireNumeric(genExpr: GeneratedExpression) =
@@ -189,6 +188,16 @@ object CodeGenUtils {
throw new CodeGenException("Interval expression type expected.")
}
+  /**
+    * Ensures that the given expression has an array result type.
+    * Throws a [[CodeGenException]] otherwise.
+    */
+  def requireArray(genExpr: GeneratedExpression) = {
+    val hasArrayType = TypeCheckUtils.isArray(genExpr.resultType)
+    if (!hasArrayType) {
+      throw new CodeGenException("Array expression type expected.")
+    }
+  }
+
+  /**
+    * Ensures that the given expression has an integer result type.
+    * Throws a [[CodeGenException]] otherwise.
+    */
+  def requireInteger(genExpr: GeneratedExpression) = {
+    val hasIntegerType = TypeCheckUtils.isInteger(genExpr.resultType)
+    if (!hasIntegerType) {
+      throw new CodeGenException("Integer expression type expected.")
+    }
+  }
+
// ----------------------------------------------------------------------------------------------
def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType)
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
index f7d6863..7caad12 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
@@ -976,6 +976,27 @@ class CodeGenerator(
requireString(left)
generateArithmeticOperator("+", nullCheck, resultType, left, right)
+ // arrays
+ case ARRAY_VALUE_CONSTRUCTOR =>
+ generateArray(this, resultType, operands)
+
+ case ITEM =>
+ val array = operands.head
+ val index = operands(1)
+ requireArray(array)
+ requireInteger(index)
+ generateArrayElementAt(this, array, index)
+
+ case CARDINALITY =>
+ val array = operands.head
+ requireArray(array)
+ generateArrayCardinality(nullCheck, array)
+
+ case ELEMENT =>
+ val array = operands.head
+ requireArray(array)
+ generateArrayElement(this, array)
+
// advanced scalar functions
case sqlOperator: SqlOperator =>
val callGen = FunctionGenerator.getCallGenerator(
@@ -1394,6 +1415,22 @@ class CodeGenerator(
}
/**
+ * Adds a reusable array to the member area of the generated [[Function]].
+ */
+  def addReusableArray(clazz: Class[_], size: Int): String = {
+    val arrayField = newName("array")
+    // the canonical name is the Java source form of the type, this works also for int[] etc.
+    val arrayType = clazz.getCanonicalName
+    // the size goes into the first pair of brackets of the type name
+    val sizedArrayType = arrayType.replaceFirst("\\[", s"[$size")
+    val declaration =
+      s"""
+        |transient $arrayType $arrayField =
+        | new $sizedArrayType;
+        |""".stripMargin
+    reusableMemberStatements.add(declaration)
+    arrayField
+  }
+
+ /**
* Adds a reusable timestamp to the beginning of the SAM of the generated [[Function]].
*/
def addReusableTimestamp(): String = {
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala
index 74756ef..731452f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala
@@ -63,7 +63,7 @@ class ExpressionReducer(config: TableConfig)
)
// we don't support object literals yet, we skip those constant expressions
- case (SqlTypeName.ANY, _) | (SqlTypeName.ROW, _) => None
+ case (SqlTypeName.ANY, _) | (SqlTypeName.ROW, _) | (SqlTypeName.ARRAY, _) => None
case (_, e) => Some(e)
}
@@ -101,7 +101,7 @@ class ExpressionReducer(config: TableConfig)
val unreduced = constExprs.get(i)
unreduced.getType.getSqlTypeName match {
// we insert the original expression for object literals
- case SqlTypeName.ANY | SqlTypeName.ROW =>
+ case SqlTypeName.ANY | SqlTypeName.ROW | SqlTypeName.ARRAY =>
reducedValues.add(unreduced)
case _ =>
val literal = rexBuilder.makeLiteral(
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala
index 75c0149..330e2fe 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala
@@ -21,9 +21,10 @@ import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY
import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange}
import org.apache.calcite.util.BuiltInMethod
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, PrimitiveArrayTypeInfo, SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
import org.apache.flink.api.table.codegen.CodeGenUtils._
-import org.apache.flink.api.table.codegen.{CodeGenException, GeneratedExpression}
+import org.apache.flink.api.table.codegen.{CodeGenerator, CodeGenException, GeneratedExpression}
import org.apache.flink.api.table.typeutils.TimeIntervalTypeInfo
import org.apache.flink.api.table.typeutils.TypeCheckUtils._
@@ -91,6 +92,12 @@ object ScalarOperators {
else if (isTemporal(left.resultType) && left.resultType == right.resultType) {
generateComparison("==", nullCheck, left, right)
}
+ // array types
+ else if (isArray(left.resultType) && left.resultType == right.resultType) {
+ generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
+ (leftTerm, rightTerm) => s"java.util.Arrays.equals($leftTerm, $rightTerm)"
+ }
+ }
// comparable types of same type
else if (isComparable(left.resultType) && left.resultType == right.resultType) {
generateComparison("==", nullCheck, left, right)
@@ -125,6 +132,12 @@ object ScalarOperators {
else if (isTemporal(left.resultType) && left.resultType == right.resultType) {
generateComparison("!=", nullCheck, left, right)
}
+ // array types
+ else if (isArray(left.resultType) && left.resultType == right.resultType) {
+ generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
+ (leftTerm, rightTerm) => s"!java.util.Arrays.equals($leftTerm, $rightTerm)"
+ }
+ }
// comparable types
else if (isComparable(left.resultType) && left.resultType == right.resultType) {
generateComparison("!=", nullCheck, left, right)
@@ -428,7 +441,7 @@ object ScalarOperators {
// Date/Time/Timestamp -> String
case (dtt: SqlTimeTypeInfo[_], STRING_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
- (operandTerm) => s"""${internalToTimePointCode(dtt, operandTerm)}.toString()"""
+ (operandTerm) => s"${internalToTimePointCode(dtt, operandTerm)}.toString()"
}
// Interval Months -> String
@@ -447,6 +460,18 @@ object ScalarOperators {
(operandTerm) => s"$method($operandTerm, $timeUnitRange, 3)" // milli second precision
}
+ // Object array -> String
+ case (_:ObjectArrayTypeInfo[_, _], STRING_TYPE_INFO) =>
+ generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+ (operandTerm) => s"java.util.Arrays.deepToString($operandTerm)"
+ }
+
+ // Primitive array -> String
+ case (_:PrimitiveArrayTypeInfo[_], STRING_TYPE_INFO) =>
+ generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+ (operandTerm) => s"java.util.Arrays.toString($operandTerm)"
+ }
+
// * (not Date/Time/Timestamp) -> String
case (_, STRING_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
@@ -701,6 +726,173 @@ object ScalarOperators {
generateUnaryArithmeticOperator(operator, nullCheck, operand.resultType, operand)
}
+ /**
+ * Generates code that fills a reusable array member with the results of the given
+ * element expressions and returns an expression referring to that array.
+ *
+ * @param codeGenerator generator used to register the reusable array and to box elements
+ * @param resultType type information of the array to create
+ * @param elements generated expressions of the array elements, in order
+ * @return expression whose result term is the array field; it is marked as never null
+ */
+ def generateArray(
+ codeGenerator: CodeGenerator,
+ resultType: TypeInformation[_],
+ elements: Seq[GeneratedExpression])
+ : GeneratedExpression = {
+ // NOTE(review): the array is a member field, so every evaluation writes into the same
+ // instance -- confirm that consumers do not hold on to a previous result
+ val arrayTerm = codeGenerator.addReusableArray(resultType.getTypeClass, elements.size)
+
+ // NOTE(review): only object and primitive array type infos are matched; any other
+ // result type ends in a MatchError
+ val boxedElements: Seq[GeneratedExpression] = resultType match {
+
+ case oati: ObjectArrayTypeInfo[_, _] =>
+ // we box the elements to also represent null values
+ val boxedTypeTerm = boxedTypeTermForTypeInfo(oati.getComponentInfo)
+
+ elements.map { e =>
+ val boxedExpr = codeGenerator.generateOutputFieldBoxing(e)
+ val exprOrNull: String = if (codeGenerator.nullCheck) {
+ s"${boxedExpr.nullTerm} ? null : ($boxedTypeTerm) ${boxedExpr.resultTerm}"
+ } else {
+ boxedExpr.resultTerm
+ }
+ boxedExpr.copy(resultTerm = exprOrNull)
+ }
+
+ // no boxing necessary
+ case _: PrimitiveArrayTypeInfo[_] => elements
+ }
+
+ // one assignment per element, each preceded by the code that computes the element
+ val code = boxedElements
+ .zipWithIndex
+ .map { case (element, idx) =>
+ s"""
+ |${element.code}
+ |$arrayTerm[$idx] = ${element.resultTerm};
+ |""".stripMargin
+ }
+ .mkString("\n")
+
+ GeneratedExpression(arrayTerm, GeneratedExpression.NEVER_NULL, code, resultType)
+ }
+
+ /**
+ * Generates code that reads the element of an array at a given position. The generated
+ * code subtracts 1 from the index before accessing the array, i.e. positions start at 1.
+ *
+ * @param codeGenerator generator providing the null check setting and the unboxing code
+ * @param array generated expression of the array
+ * @param index generated expression of the position
+ * @return generated expression of the accessed element
+ */
+ def generateArrayElementAt(
+ codeGenerator: CodeGenerator,
+ array: GeneratedExpression,
+ index: GeneratedExpression)
+ : GeneratedExpression = {
+
+ val resultTerm = newName("result")
+
+ // NOTE(review): only object and primitive array type infos are matched here
+ array.resultType match {
+
+ // unbox object array types
+ case oati: ObjectArrayTypeInfo[_, _] =>
+ // get boxed array element
+ val resultTypeTerm = boxedTypeTermForTypeInfo(oati.getComponentInfo)
+
+ // with null checks the boxed element is null as soon as the array or the index is null
+ val arrayAccessCode = if (codeGenerator.nullCheck) {
+ s"""
+ |${array.code}
+ |${index.code}
+ |$resultTypeTerm $resultTerm = (${array.nullTerm} || ${index.nullTerm}) ?
+ | null : ${array.resultTerm}[${index.resultTerm} - 1];
+ |""".stripMargin
+ } else {
+ s"""
+ |${array.code}
+ |${index.code}
+ |$resultTypeTerm $resultTerm = ${array.resultTerm}[${index.resultTerm} - 1];
+ |""".stripMargin
+ }
+
+ // generate unbox code
+ val unboxing = codeGenerator.generateInputFieldUnboxing(oati.getComponentInfo, resultTerm)
+
+ // the array access has to run before the unboxing code that reads its result
+ unboxing.copy(code =
+ s"""
+ |$arrayAccessCode
+ |${unboxing.code}
+ |""".stripMargin
+ )
+
+ // no unboxing necessary
+ case pati: PrimitiveArrayTypeInfo[_] =>
+ generateOperatorIfNotNull(codeGenerator.nullCheck, pati.getComponentType, array, index) {
+ (leftTerm, rightTerm) => s"$leftTerm[$rightTerm - 1]"
+ }
+ }
+ }
+
+ /**
+ * Generates code that returns the sole element of an array. The generated code switches
+ * on the array length: length 0 yields the default value of the element type (flagged as
+ * null when null checks are enabled), length 1 yields the first element, and any other
+ * length throws a RuntimeException("Array has more than one element.").
+ *
+ * @param codeGenerator generator providing the null check setting and the unboxing code
+ * @param array generated expression of the array
+ * @return generated expression of the single element, typed with the array's component type
+ */
+ def generateArrayElement(
+ codeGenerator: CodeGenerator,
+ array: GeneratedExpression)
+ : GeneratedExpression = {
+
+ val nullTerm = newName("isNull")
+ val resultTerm = newName("result")
+ // NOTE(review): only object and primitive array type infos are matched here
+ val resultType = array.resultType match {
+ case oati: ObjectArrayTypeInfo[_, _] => oati.getComponentInfo
+ case pati: PrimitiveArrayTypeInfo[_] => pati.getComponentType
+ }
+ val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
+ val defaultValue = primitiveDefaultValue(resultType)
+
+ // with null checks a null array is handled like an empty array (length 0)
+ val arrayLengthCode = if (codeGenerator.nullCheck) {
+ s"${array.nullTerm} ? 0 : ${array.resultTerm}.length"
+ } else {
+ s"${array.resultTerm}.length"
+ }
+
+ val arrayAccessCode = array.resultType match {
+ case oati: ObjectArrayTypeInfo[_, _] =>
+ // generate unboxing code
+ val unboxing = codeGenerator.generateInputFieldUnboxing(
+ oati.getComponentInfo,
+ s"${array.resultTerm}[0]")
+
+ // the unboxing code is only emitted inside the single-element case
+ s"""
+ |${array.code}
+ |${if (codeGenerator.nullCheck) s"boolean $nullTerm;" else "" }
+ |$resultTypeTerm $resultTerm;
+ |switch ($arrayLengthCode) {
+ | case 0:
+ | ${if (codeGenerator.nullCheck) s"$nullTerm = true;" else "" }
+ | $resultTerm = $defaultValue;
+ | break;
+ | case 1:
+ | ${unboxing.code}
+ | ${if (codeGenerator.nullCheck) s"$nullTerm = ${unboxing.nullTerm};" else "" }
+ | $resultTerm = ${unboxing.resultTerm};
+ | break;
+ | default:
+ | throw new RuntimeException("Array has more than one element.");
+ |}
+ |""".stripMargin
+
+ case pati: PrimitiveArrayTypeInfo[_] =>
+ // primitive elements are read directly and are never flagged as null
+ s"""
+ |${array.code}
+ |${if (codeGenerator.nullCheck) s"boolean $nullTerm;" else "" }
+ |$resultTypeTerm $resultTerm;
+ |switch ($arrayLengthCode) {
+ | case 0:
+ | ${if (codeGenerator.nullCheck) s"$nullTerm = true;" else "" }
+ | $resultTerm = $defaultValue;
+ | break;
+ | case 1:
+ | ${if (codeGenerator.nullCheck) s"$nullTerm = false;" else "" }
+ | $resultTerm = ${array.resultTerm}[0];
+ | break;
+ | default:
+ | throw new RuntimeException("Array has more than one element.");
+ |}
+ |""".stripMargin
+ }
+
+ // NOTE(review): without null checks the null term is never declared in the generated
+ // code although it is still handed to the expression -- confirm consumers ignore it
+ GeneratedExpression(resultTerm, nullTerm, arrayAccessCode, resultType)
+ }
+
+  /**
+    * Generates code that returns the number of elements of an array.
+    *
+    * @param nullCheck null check flag that is forwarded to the unary operator helper
+    * @param array generated expression of the array
+    * @return generated expression of type INT holding the length of the array
+    */
+  def generateArrayCardinality(
+      nullCheck: Boolean,
+      array: GeneratedExpression)
+    : GeneratedExpression = {
+
+    generateUnaryOperatorIfNotNull(nullCheck, INT_TYPE_INFO, array) {
+      // use the term handed in by the helper, like the other unary operators in this object,
+      // instead of reaching around it for the operand's result term
+      (operandTerm) => s"$operandTerm.length"
+    }
+  }
+
// ----------------------------------------------------------------------------------------------
private def generateUnaryOperatorIfNotNull(
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
index a926717..c960a79 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
@@ -48,6 +48,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
// Keyword
+ lazy val ARRAY: Keyword = Keyword("Array")
lazy val AS: Keyword = Keyword("as")
lazy val COUNT: Keyword = Keyword("count")
lazy val AVG: Keyword = Keyword("avg")
@@ -88,7 +89,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val FLATTEN: Keyword = Keyword("flatten")
def functionIdent: ExpressionParser.Parser[String] =
- not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~
+ not(ARRAY) ~ not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~
not(SUM) ~ not(START) ~ not(END)~ not(CAST) ~ not(NULL) ~
not(IF) ~> super.ident
@@ -298,6 +299,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
// prefix operators
+ // array constructor: the ARRAY keyword followed by a parenthesized, comma-separated
+ // list of expressions; repsep also accepts an empty list
+ lazy val prefixArray: PackratParser[Expression] =
+ ARRAY ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { elements => ArrayConstructor(elements) }
+
lazy val prefixSum: PackratParser[Expression] =
SUM ~ "(" ~> expression <~ ")" ^^ { e => Sum(e) }
@@ -372,7 +376,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
FLATTEN ~ "(" ~> composite <~ ")" ^^ { e => Flattening(e) }
lazy val prefixed: PackratParser[Expression] =
- prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg | prefixStart | prefixEnd |
+ prefixArray | prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg |
+ prefixStart | prefixEnd |
prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | prefixIf | prefixExtract |
prefixFloor | prefixCeil | prefixGet | prefixFlattening |
prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala
index c071c59..8657534 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala
@@ -18,13 +18,16 @@
package org.apache.flink.api.table.expressions
-import java.math.BigDecimal
+import java.lang.{Boolean => JBoolean, Byte => JByte, Short => JShort, Integer => JInteger, Long => JLong, Float => JFloat, Double => JDouble}
+import java.math.{BigDecimal => JBigDecimal}
+import java.sql.{Date, Time, Timestamp}
import org.apache.calcite.avatica.util.TimeUnit
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.{RexBuilder, RexNode}
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.api.table.ValidationException
import org.apache.flink.api.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo}
object ExpressionUtils {
@@ -54,6 +57,48 @@ object ExpressionUtils {
throw new IllegalArgumentException("Invalid value for row interval literal.")
}
+ /**
+   * Converts a Scala/Java array into an [[ArrayConstructor]] expression.
+   *
+   * Arrays of primitives, boxed primitives, strings, decimals and SQL time types are
+   * converted element by element into literals (Scala `BigDecimal`s are unwrapped to
+   * `java.math.BigDecimal` first). Arrays of arrays are converted recursively.
+   *
+   * @param array the array to convert
+   * @return an [[ArrayConstructor]] holding one expression per array element
+   * @throws ValidationException if the array is null, has an unsupported element type
+   *                             (this includes empty arrays of such a type) or is a nested
+   *                             array that contains null or non-array elements
+   */
+ private[flink] def convertArray(array: Array[_]): Expression = {
+   def createArray(): Expression = ArrayConstructor(array.map(Literal(_)))
+
+   array match {
+     // a null reference would otherwise fall through to the nested case and fail there
+     // with a NullPointerException
+     case null =>
+       throw ValidationException("Unsupported array type.")
+
+     // primitives
+     case _: Array[Boolean] | _: Array[Byte] | _: Array[Short] | _: Array[Int] |
+          _: Array[Long] | _: Array[Float] | _: Array[Double] =>
+       createArray()
+
+     // boxed types
+     case _: Array[JBoolean] | _: Array[JByte] | _: Array[JShort] | _: Array[JInteger] |
+          _: Array[JLong] | _: Array[JFloat] | _: Array[JDouble] =>
+       createArray()
+
+     // others
+     case _: Array[String] | _: Array[JBigDecimal] | _: Array[Date] | _: Array[Time] |
+          _: Array[Timestamp] =>
+       createArray()
+
+     // Scala decimals are unwrapped to their java.math.BigDecimal representation
+     case bda: Array[BigDecimal] =>
+       ArrayConstructor(bda.map { bd => Literal(bd.bigDecimal) })
+
+     // nested: every element has to be an array itself; checking only the first element
+     // would let null or non-array elements fail later with a NullPointerException or a
+     // ClassCastException instead of a ValidationException
+     case nested if nested.nonEmpty && nested.forall(_.isInstanceOf[Array[_]]) =>
+       ArrayConstructor(nested.map { na => convertArray(na.asInstanceOf[Array[_]]) })
+
+     case _ =>
+       throw ValidationException("Unsupported array type.")
+   }
+ }
+
// ----------------------------------------------------------------------------------------------
// RexNode conversion functions (see org.apache.calcite.sql2rel.StandardConvertletTable)
// ----------------------------------------------------------------------------------------------
@@ -61,7 +106,7 @@ object ExpressionUtils {
/**
* Copy of [[org.apache.calcite.sql2rel.StandardConvertletTable#getFactor()]].
*/
- private[flink] def getFactor(unit: TimeUnit): BigDecimal = unit match {
+ private[flink] def getFactor(unit: TimeUnit): JBigDecimal = unit match {
case TimeUnit.DAY => java.math.BigDecimal.ONE
case TimeUnit.HOUR => TimeUnit.DAY.multiplier
case TimeUnit.MINUTE => TimeUnit.HOUR.multiplier
@@ -78,20 +123,20 @@ object ExpressionUtils {
rexBuilder: RexBuilder,
resType: RelDataType,
res: RexNode,
- value: BigDecimal)
+ value: JBigDecimal)
: RexNode = {
- if (value == BigDecimal.ONE) return res
+ if (value == JBigDecimal.ONE) return res
rexBuilder.makeCall(SqlStdOperatorTable.MOD, res, rexBuilder.makeExactLiteral(value, resType))
}
/**
* Copy of [[org.apache.calcite.sql2rel.StandardConvertletTable#divide()]].
*/
- private[flink] def divide(rexBuilder: RexBuilder, res: RexNode, value: BigDecimal): RexNode = {
- if (value == BigDecimal.ONE) return res
- if (value.compareTo(BigDecimal.ONE) < 0 && value.signum == 1) {
+ private[flink] def divide(rexBuilder: RexBuilder, res: RexNode, value: JBigDecimal): RexNode = {
+ if (value == JBigDecimal.ONE) return res
+ if (value.compareTo(JBigDecimal.ONE) < 0 && value.signum == 1) {
try {
- val reciprocal = BigDecimal.ONE.divide(value, BigDecimal.ROUND_UNNECESSARY)
+ val reciprocal = JBigDecimal.ONE.divide(value, JBigDecimal.ROUND_UNNECESSARY)
return rexBuilder.makeCall(
SqlStdOperatorTable.MULTIPLY,
res,
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/array.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/array.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/array.scala
new file mode 100644
index 0000000..78084de
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/array.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.expressions
+
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.calcite.tools.RelBuilder
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo.INT_TYPE_INFO
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, PrimitiveArrayTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
+import org.apache.flink.api.table.FlinkRelBuilder
+import org.apache.flink.api.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
+
+import scala.collection.JavaConverters._
+
+/**
+  * Expression that builds an array out of the given element expressions.
+  *
+  * @param elements expressions of the array elements; validation requires at least one
+  *                 element and the same result type for all of them
+  */
+case class ArrayConstructor(elements: Seq[Expression]) extends Expression {
+
+  override private[flink] def children: Seq[Expression] = elements
+
+  // Translates into a call of Calcite's ARRAY value constructor whose return type is
+  // created from this expression's result type.
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    val typeFactory = relBuilder.asInstanceOf[FlinkRelBuilder].getTypeFactory
+    val arrayType = typeFactory.createTypeFromTypeInfo(resultType)
+    val operands = elements.map(_.toRexNode).toList.asJava
+    relBuilder.getRexBuilder.makeCall(
+      arrayType,
+      SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+      operands)
+  }
+
+  override def toString = s"array(${elements.mkString(", ")})"
+
+  // the component type is taken from the first element
+  override private[flink] def resultType = ObjectArrayTypeInfo.getInfoFor(elements.head.resultType)
+
+  override private[flink] def validateInput(): ValidationResult = {
+    if (elements.isEmpty) {
+      ValidationFailure("Empty arrays are not supported yet.")
+    } else {
+      val expectedType = elements.head.resultType
+      if (elements.exists(_.resultType != expectedType)) {
+        ValidationFailure("Not all elements of the array have the same type.")
+      } else {
+        ValidationSuccess
+      }
+    }
+  }
+}
+
+/**
+  * Expression that accesses a single element of an array by its index.
+  *
+  * @param array expression that has to evaluate to an object or primitive array
+  * @param index expression of integer type; literal indices smaller than 1 are rejected
+  *              during validation
+  */
+case class ArrayElementAt(array: Expression, index: Expression) extends Expression {
+
+  override private[flink] def children: Seq[Expression] = Seq(array, index)
+
+  // translates into a call of Calcite's ITEM operator
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    val rexBuilder = relBuilder.getRexBuilder
+    rexBuilder.makeCall(SqlStdOperatorTable.ITEM, array.toRexNode, index.toRexNode)
+  }
+
+  override def toString = s"($array).at($index)"
+
+  // the result has the component type of the accessed array
+  override private[flink] def resultType = array.resultType match {
+    case objectArray: ObjectArrayTypeInfo[_, _] => objectArray.getComponentInfo
+    case primitiveArray: PrimitiveArrayTypeInfo[_] => primitiveArray.getComponentType
+  }
+
+  override private[flink] def validateInput(): ValidationResult = array.resultType match {
+    case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] =>
+      index match {
+        case _ if index.resultType != INT_TYPE_INFO =>
+          ValidationFailure(
+            s"Array element access needs an integer index but was '${index.resultType}'.")
+        // check for common user mistake
+        case Literal(value: Int, INT_TYPE_INFO) if value < 1 =>
+          ValidationFailure(
+            s"Array element access needs an index starting at 1 but was $value.")
+        case _ =>
+          ValidationSuccess
+      }
+    case other =>
+      ValidationFailure(s"Array expected but was '$other'.")
+  }
+}
+
+/**
+  * Expression that returns the number of elements of an array as an integer.
+  *
+  * @param array expression that has to evaluate to an object or primitive array
+  */
+case class ArrayCardinality(array: Expression) extends Expression {
+
+  override private[flink] def children: Seq[Expression] = Seq(array)
+
+  // translates into a call of Calcite's CARDINALITY operator
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    val rexBuilder = relBuilder.getRexBuilder
+    rexBuilder.makeCall(SqlStdOperatorTable.CARDINALITY, array.toRexNode)
+  }
+
+  override def toString = s"($array).cardinality()"
+
+  override private[flink] def resultType = BasicTypeInfo.INT_TYPE_INFO
+
+  override private[flink] def validateInput(): ValidationResult = {
+    val inputType = array.resultType
+    inputType match {
+      case _: ObjectArrayTypeInfo[_, _] => ValidationSuccess
+      case _: PrimitiveArrayTypeInfo[_] => ValidationSuccess
+      case _ => ValidationFailure(s"Array expected but was '$inputType'.")
+    }
+  }
+}
+
+/**
+ * Expression that applies Calcite's ELEMENT operator to an array expression.
+ *
+ * @param array expression that has to evaluate to an object or primitive array
+ */
+case class ArrayElement(array: Expression) extends Expression {
+
+ override private[flink] def children: Seq[Expression] = Seq(array)
+
+ // translates into a call of SqlStdOperatorTable.ELEMENT with the array as its only operand
+ override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+ relBuilder
+ .getRexBuilder
+ .makeCall(SqlStdOperatorTable.ELEMENT, array.toRexNode)
+ }
+
+ override def toString = s"($array).element()"
+
+ // The result has the component type of the array. Non-array input types are not matched
+ // here; they are reported by validateInput() instead.
+ override private[flink] def resultType = array.resultType match {
+ case oati: ObjectArrayTypeInfo[_, _] => oati.getComponentInfo
+ case pati: PrimitiveArrayTypeInfo[_] => pati.getComponentType
+ }
+
+ // only object arrays and primitive arrays are accepted as input
+ override private[flink] def validateInput(): ValidationResult = {
+ array.resultType match {
+ case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => ValidationSuccess
+ case other@_ => ValidationFailure(s"Array expected but was '$other'.")
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
index d5244d0..5a150f8 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
@@ -36,7 +36,6 @@ abstract class BinaryComparison extends BinaryExpression {
override private[flink] def resultType = BOOLEAN_TYPE_INFO
- // TODO: tighten this rule once we implemented type coercion rules during validation
override private[flink] def validateInput(): ValidationResult =
(left.resultType, right.resultType) match {
case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
@@ -56,7 +55,6 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
override private[flink] def validateInput(): ValidationResult =
(left.resultType, right.resultType) match {
case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
- // TODO widen this rule once we support custom objects as types (FLINK-3916)
case (lType, rType) if lType == rType => ValidationSuccess
case (lType, rType) =>
ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
@@ -71,7 +69,6 @@ case class NotEqualTo(left: Expression, right: Expression) extends BinaryCompari
override private[flink] def validateInput(): ValidationResult =
(left.resultType, right.resultType) match {
case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
- // TODO widen this rule once we support custom objects as types (FLINK-3916)
case (lType, rType) if lType == rType => ValidationSuccess
case (lType, rType) =>
ValidationFailure(s"Inequality predicate on incompatible types: $lType and $rType")
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
index c093f1a..22b77b4 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
@@ -143,7 +143,13 @@ object ProjectionTranslator {
case sfc @ ScalarFunctionCall(clazz, args) =>
val newArgs: Seq[Expression] = args
.map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
- sfc.makeCopy(Array(clazz,newArgs))
+ sfc.makeCopy(Array(clazz, newArgs))
+
+ // array constructor
+ case c @ ArrayConstructor(args) =>
+ val newArgs = c.elements
+ .map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
+ c.makeCopy(Array(newArgs))
// General expression
case e: Expression =>
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/ArrayRelDataType.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/ArrayRelDataType.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/ArrayRelDataType.scala
new file mode 100644
index 0000000..92fcb83
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/ArrayRelDataType.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.schema
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql.`type`.ArraySqlType
+import org.apache.flink.api.common.typeinfo.TypeInformation
+
+/**
+  * Array type for Calcite that additionally carries the Flink type information of the array.
+  *
+  * Flink distinguishes between primitive arrays (int[], double[], ...) and
+  * object arrays (Integer[], MyPojo[], ...). This custom type supports both cases.
+  *
+  * @param typeInfo Flink type information of the array
+  * @param elementType Calcite type of the array elements
+  * @param isNullable whether the array type is nullable
+  */
+class ArrayRelDataType(
+    val typeInfo: TypeInformation[_],
+    elementType: RelDataType,
+    isNullable: Boolean)
+  extends ArraySqlType(elementType, isNullable) {
+
+  override def toString = s"ARRAY($typeInfo)"
+
+  def canEqual(other: Any): Boolean = other match {
+    case _: ArrayRelDataType => true
+    case _ => false
+  }
+
+  // two array types are equal if the Calcite types and the Flink type information match
+  override def equals(other: Any): Boolean = other match {
+    case that: ArrayRelDataType =>
+      super.equals(that) && that.canEqual(this) && typeInfo == that.typeInfo
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = typeInfo.hashCode()
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala
index aa8614b..e30e273 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala
@@ -17,8 +17,9 @@
*/
package org.apache.flink.api.table.typeutils
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO, STRING_TYPE_INFO}
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, NumericTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO, INT_TYPE_INFO, STRING_TYPE_INFO}
+import org.apache.flink.api.common.typeinfo._
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
import org.apache.flink.api.table.validate._
object TypeCheckUtils {
@@ -61,8 +62,15 @@ object TypeCheckUtils {
def isDecimal(dataType: TypeInformation[_]): Boolean = dataType == BIG_DEC_TYPE_INFO
+ // True only for INT_TYPE_INFO; other integral types (e.g. long or short) do not match.
+ def isInteger(dataType: TypeInformation[_]): Boolean = dataType == INT_TYPE_INFO
+
+ // True if the type describes an object array or a primitive array.
+ def isArray(dataType: TypeInformation[_]): Boolean = dataType match {
+ case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => true
+ case _ => false
+ }
+
def isComparable(dataType: TypeInformation[_]): Boolean =
- classOf[Comparable[_]].isAssignableFrom(dataType.getTypeClass)
+ classOf[Comparable[_]].isAssignableFrom(dataType.getTypeClass) && !isArray(dataType)
def assertNumericExpr(
dataType: TypeInformation[_],
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
index dc68b89..8e409cc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
@@ -185,7 +185,12 @@ object FunctionCatalog {
"localTime" -> classOf[LocalTime],
"localTimestamp" -> classOf[LocalTimestamp],
"quarter" -> classOf[Quarter],
- "temporalOverlaps" -> classOf[TemporalOverlaps]
+ "temporalOverlaps" -> classOf[TemporalOverlaps],
+
+ // array
+ "cardinality" -> classOf[ArrayCardinality],
+ "at" -> classOf[ArrayElementAt],
+ "element" -> classOf[ArrayElement]
// TODO implement function overloading here
// "floor" -> classOf[TemporalFloor]
@@ -258,6 +263,11 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
SqlStdOperatorTable.MIN,
SqlStdOperatorTable.MAX,
SqlStdOperatorTable.AVG,
+ // ARRAY OPERATORS
+ SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+ SqlStdOperatorTable.ITEM,
+ SqlStdOperatorTable.CARDINALITY,
+ SqlStdOperatorTable.ELEMENT,
// SPECIAL OPERATORS
SqlStdOperatorTable.ROW,
SqlStdOperatorTable.OVERLAPS,
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/test/resources/log4j-test.properties
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/resources/log4j-test.properties b/flink-libraries/flink-table/src/test/resources/log4j-test.properties
index f713aa8..4c74d85 100644
--- a/flink-libraries/flink-table/src/test/resources/log4j-test.properties
+++ b/flink-libraries/flink-table/src/test/resources/log4j-test.properties
@@ -18,7 +18,7 @@
# Set root logger level to OFF to not flood build logs
# set manually to INFO for debugging purposes
-log4j.rootLogger=INFO, testlogger
+log4j.rootLogger=OFF, testlogger
# testlogger is set to be a ConsoleAppender.
log4j.appender.testlogger=org.apache.log4j.ConsoleAppender
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ArrayTypeTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ArrayTypeTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ArrayTypeTest.scala
new file mode 100644
index 0000000..034ce0b
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ArrayTypeTest.scala
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.expressions
+
+import java.sql.Date
+
+import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils.ExpressionTestBase
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.{Row, Types, ValidationException}
+import org.junit.Test
+
+class ArrayTypeTest extends ExpressionTestBase {
+
+ @Test(expected = classOf[ValidationException])
+ def testObviousInvalidIndexTableApi(): Unit = {
+ testTableApi('f2.at(0), "FAIL", "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testEmptyArraySql(): Unit = {
+ testSqlApi("ARRAY[]", "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testEmptyArrayTableApi(): Unit = {
+ testTableApi("FAIL", "array()", "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testNullArraySql(): Unit = {
+ testSqlApi("ARRAY[NULL]", "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testDifferentTypesArraySql(): Unit = {
+ testSqlApi("ARRAY[1, TRUE]", "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testDifferentTypesArrayTableApi(): Unit = {
+ testTableApi("FAIL", "array(1, true)", "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testUnsupportedComparison(): Unit = {
+ testAllApis(
+ 'f2 <= 'f5.at(1),
+ "f2 <= f5.at(1)",
+ "f2 <= f5[1]",
+ "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testElementNonArray(): Unit = {
+ testTableApi(
+ 'f0.element(),
+ "FAIL",
+ "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testElementNonArraySql(): Unit = {
+ testSqlApi(
+ "ELEMENT(f0)",
+ "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testCardinalityOnNonArray(): Unit = {
+ testTableApi('f0.cardinality(), "FAIL", "FAIL")
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testCardinalityOnNonArraySql(): Unit = {
+ testSqlApi("CARDINALITY(f0)", "FAIL")
+ }
+
+ @Test
+ def testArrayLiterals(): Unit = {
+ // primitive literals
+ testAllApis(array(1, 2, 3), "array(1, 2, 3)", "ARRAY[1, 2, 3]", "[1, 2, 3]")
+
+ testAllApis(
+ array(true, true, true),
+ "array(true, true, true)",
+ "ARRAY[TRUE, TRUE, TRUE]",
+ "[true, true, true]")
+
+ // object literals
+ testTableApi(array(BigDecimal(1), BigDecimal(1)), "array(1p, 1p)", "[1, 1]")
+
+ testAllApis(
+ array(array(array(1), array(1))),
+ "array(array(array(1), array(1)))",
+ "ARRAY[ARRAY[ARRAY[1], ARRAY[1]]]",
+ "[[[1], [1]]]")
+
+ testAllApis(
+ array(1 + 1, 3 * 3),
+ "array(1 + 1, 3 * 3)",
+ "ARRAY[1 + 1, 3 * 3]",
+ "[2, 9]")
+
+ testAllApis(
+ array(Null(Types.INT), 1),
+ "array(Null(INT), 1)",
+ "ARRAY[NULLIF(1,1), 1]",
+ "[null, 1]")
+
+ testAllApis(
+ array(array(Null(Types.INT), 1)),
+ "array(array(Null(INT), 1))",
+ "ARRAY[ARRAY[NULLIF(1,1), 1]]",
+ "[[null, 1]]")
+
+ // implicit conversion
+ testTableApi(
+ Array(1, 2, 3),
+ "array(1, 2, 3)",
+ "[1, 2, 3]")
+
+ testTableApi(
+ Array[Integer](1, 2, 3),
+ "array(1, 2, 3)",
+ "[1, 2, 3]")
+
+ testAllApis(
+ Array(Date.valueOf("1985-04-11")),
+ "array('1985-04-11'.toDate)",
+ "ARRAY[DATE '1985-04-11']",
+ "[1985-04-11]")
+
+ testAllApis(
+ Array(BigDecimal(2.0002), BigDecimal(2.0003)),
+ "Array(2.0002p, 2.0003p)",
+ "ARRAY[CAST(2.0002 AS DECIMAL), CAST(2.0003 AS DECIMAL)]",
+ "[2.0002, 2.0003]")
+
+ testAllApis(
+ Array(Array(x = true)),
+ "Array(Array(true))",
+ "ARRAY[ARRAY[TRUE]]",
+ "[[true]]")
+
+ testAllApis(
+ Array(Array(1, 2, 3), Array(3, 2, 1)),
+ "Array(Array(1, 2, 3), Array(3, 2, 1))",
+ "ARRAY[ARRAY[1, 2, 3], ARRAY[3, 2, 1]]",
+ "[[1, 2, 3], [3, 2, 1]]")
+ }
+
+ // Checks array construction from fields, printing of array fields and (nested) element
+ // access via at(). Each expression is given in Scala Table API, string Table API and SQL
+ // notation, followed by the expected result.
+ @Test
+ def testArrayField(): Unit = {
+   // array built from fields
+   testAllApis(
+     array('f0, 'f1),
+     "array(f0, f1)",
+     "ARRAY[f0, f1]",
+     "[null, 42]")
+
+   // plain array fields
+   testAllApis(
+     'f2,
+     "f2",
+     "f2",
+     "[1, 2, 3]")
+
+   testAllApis(
+     'f3,
+     "f3",
+     "f3",
+     "[1984-03-12, 1984-02-10]")
+
+   testAllApis(
+     'f5,
+     "f5",
+     "f5",
+     "[[1, 2, 3], null]")
+
+   testAllApis(
+     'f6,
+     "f6",
+     "f6",
+     "[1, null, null, 4]")
+
+   // element access; indices start at 1
+   testAllApis(
+     'f2.at(1),
+     "f2.at(1)",
+     "f2[1]",
+     "1")
+
+   testAllApis(
+     'f3.at(1),
+     "f3.at(1)",
+     "f3[1]",
+     "1984-03-12")
+
+   testAllApis(
+     'f3.at(2),
+     "f3.at(2)",
+     "f3[2]",
+     "1984-02-10")
+
+   // nested element access
+   testAllApis(
+     'f5.at(1).at(2),
+     "f5.at(1).at(2)",
+     "f5[1][2]",
+     "2")
+
+   // access into a null inner array
+   testAllApis(
+     'f5.at(2).at(2),
+     "f5.at(2).at(2)",
+     "f5[2][2]",
+     "null")
+
+   // access into a null array field
+   testAllApis(
+     'f4.at(2).at(2),
+     "f4.at(2).at(2)",
+     "f4[2][2]",
+     "null")
+ }
+
+ @Test
+ def testArrayOperations(): Unit = {
+ // cardinality
+ testAllApis(
+ 'f2.cardinality(),
+ "f2.cardinality()",
+ "CARDINALITY(f2)",
+ "3")
+
+ testAllApis(
+ 'f4.cardinality(),
+ "f4.cardinality()",
+ "CARDINALITY(f4)",
+ "null")
+
+ // element
+ testAllApis(
+ 'f9.element(),
+ "f9.element()",
+ "ELEMENT(f9)",
+ "1")
+
+ testAllApis(
+ 'f8.element(),
+ "f8.element()",
+ "ELEMENT(f8)",
+ "4.0")
+
+ testAllApis(
+ 'f10.element(),
+ "f10.element()",
+ "ELEMENT(f10)",
+ "null")
+
+ testAllApis(
+ 'f4.element(),
+ "f4.element()",
+ "ELEMENT(f4)",
+ "null")
+
+ // comparison
+ testAllApis(
+ 'f2 === 'f5.at(1),
+ "f2 === f5.at(1)",
+ "f2 = f5[1]",
+ "true")
+
+ testAllApis(
+ 'f6 === array(1, 2, 3),
+ "f6 === array(1, 2, 3)",
+ "f6 = ARRAY[1, 2, 3]",
+ "false")
+
+ testAllApis(
+ 'f2 !== 'f5.at(1),
+ "f2 !== f5.at(1)",
+ "f2 <> f5[1]",
+ "false")
+
+ testAllApis(
+ 'f2 === 'f7,
+ "f2 === f7",
+ "f2 = f7",
+ "false")
+
+ testAllApis(
+ 'f2 !== 'f7,
+ "f2 !== f7",
+ "f2 <> f7",
+ "true")
+ }
+
+ // ----------------------------------------------------------------------------------------------
+
+ case class MyCaseClass(string: String, int: Int)
+
+ // Input row for all tests of this class; the tests refer to field i as f<i>.
+ // The field types are declared in the same order in typeInfo.
+ override def testData: Any = {
+ val testData = new Row(11)
+ testData.setField(0, null) // f0: null value
+ testData.setField(1, 42)
+ testData.setField(2, Array(1, 2, 3)) // f2: primitive int array
+ testData.setField(3, Array(Date.valueOf("1984-03-12"), Date.valueOf("1984-02-10")))
+ testData.setField(4, null) // f4: null array
+ testData.setField(5, Array(Array(1, 2, 3), null)) // f5: nested array with a null inner array
+ testData.setField(6, Array[Integer](1, null, null, 4)) // f6: object array with null elements
+ testData.setField(7, Array(1, 2, 3, 4))
+ testData.setField(8, Array(4.0)) // f8: single-element primitive double array
+ testData.setField(9, Array[Integer](1)) // f9: single-element object array
+ testData.setField(10, Array[Integer]()) // f10: empty object array
+ testData
+ }
+
+ // Type information of the row created in testData, one entry per field (f0 to f10).
+ override def typeInfo: TypeInformation[Any] = {
+ new RowTypeInfo(Seq(
+ Types.INT, // f0
+ Types.INT, // f1
+ PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, // f2
+ ObjectArrayTypeInfo.getInfoFor(Types.DATE), // f3
+ ObjectArrayTypeInfo.getInfoFor(ObjectArrayTypeInfo.getInfoFor(Types.INT)), // f4
+ ObjectArrayTypeInfo.getInfoFor(PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO), // f5
+ ObjectArrayTypeInfo.getInfoFor(Types.INT), // f6
+ PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, // f7
+ PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO, // f8
+ ObjectArrayTypeInfo.getInfoFor(Types.INT), // f9
+ ObjectArrayTypeInfo.getInfoFor(Types.INT) // f10
+ )).asInstanceOf[TypeInformation[Any]]
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/44140085/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala
index b892cfb..52dc848 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala
@@ -135,11 +135,13 @@ class SqlExpressionTest extends ExpressionTestBase {
testSqlApi("CAST(2 AS DOUBLE)", "2.0")
}
- @Ignore // TODO we need a special code path that flattens ROW types
@Test
def testValueConstructorFunctions(): Unit = {
- testSqlApi("ROW('hello world', 12)", "hello world") // test base only returns field 0
- testSqlApi("('hello world', 12)", "hello world") // test base only returns field 0
+ // TODO we need a special code path that flattens ROW types
+ // testSqlApi("ROW('hello world', 12)", "hello world") // test base only returns field 0
+ // testSqlApi("('hello world', 12)", "hello world") // test base only returns field 0
+ testSqlApi("ARRAY[TRUE, FALSE][2]", "false")
+ testSqlApi("ARRAY[TRUE, TRUE]", "[true, true]")
}
@Test
@@ -155,6 +157,12 @@ class SqlExpressionTest extends ExpressionTestBase {
testSqlApi("QUARTER(DATE '2016-04-12')", "2")
}
+ // SQL array functions applied to array literals: CARDINALITY of a three-element array
+ // and ELEMENT of a single-element array.
+ @Test
+ def testArrayFunctions(): Unit = {
+ testSqlApi("CARDINALITY(ARRAY[TRUE, TRUE, FALSE])", "3")
+ testSqlApi("ELEMENT(ARRAY['HELLO WORLD'])", "HELLO WORLD")
+ }
+
override def testData: Any = new Row(0)
override def typeInfo: TypeInformation[Any] =