You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2019/07/23 10:43:30 UTC
[flink] branch release-1.9 updated:
[FLINK-13314][table-planner-blink] Correct resultType of some
PlannerExpression when operands contains DecimalTypeInfo or
BigDecimalTypeInfo in Blink planner
This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push:
new 5c0521d [FLINK-13314][table-planner-blink] Correct resultType of some PlannerExpression when operands contains DecimalTypeInfo or BigDecimalTypeInfo in Blink planner
5c0521d is described below
commit 5c0521d846895598ef4b10a6a66c1b803a3504a6
Author: beyond1920 <be...@126.com>
AuthorDate: Wed Jul 17 23:01:12 2019 +0800
[FLINK-13314][table-planner-blink] Correct resultType of some PlannerExpression when operands contains DecimalTypeInfo or BigDecimalTypeInfo in Blink planner
This also fixes some minor bugs:
- Fix a minor bug in RexNodeConverter when converting BETWEEN and NOT BETWEEN to RexNode.
- Fix a minor bug in PlannerExpressionConverter when converting DataType to TypeInformation.
This closes #9152
---
.../flink/table/expressions/RexNodeConverter.java | 16 +-
.../expressions/PlannerExpressionConverter.scala | 12 +-
.../table/expressions/ReturnTypeInference.scala | 217 ++++++++
.../flink/table/expressions/arithmetic.scala | 25 +-
.../flink/table/expressions/mathExpressions.scala | 12 +-
.../table/runtime/batch/sql/DecimalITCase.scala | 5 +-
.../batch/{sql => table}/DecimalITCase.scala | 546 ++++++++-------------
7 files changed, 462 insertions(+), 371 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java
index 5528571..5dfeb97 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java
@@ -368,16 +368,24 @@ public class RexNodeConverter implements ExpressionVisitor<RexNode> {
private RexNode convertNotBetween(List<Expression> children) {
List<RexNode> childrenRexNode = convertCallChildren(children);
+ Preconditions.checkArgument(childrenRexNode.size() == 3); // expr, lower bound, upper bound; each comparison below takes two operands
+ RexNode expr = childrenRexNode.get(0);
+ RexNode lowerBound = childrenRexNode.get(1);
+ RexNode upperBound = childrenRexNode.get(2);
return relBuilder.or(
- relBuilder.call(FlinkSqlOperatorTable.LESS_THAN, childrenRexNode),
- relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN, childrenRexNode));
+ relBuilder.call(FlinkSqlOperatorTable.LESS_THAN, expr, lowerBound), // expr < lowerBound
+ relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN, expr, upperBound)); // OR expr > upperBound
}
private RexNode convertBetween(List<Expression> children) {
List<RexNode> childrenRexNode = convertCallChildren(children);
+ Preconditions.checkArgument(childrenRexNode.size() == 3); // expr, lower bound, upper bound; each comparison below takes two operands
+ RexNode expr = childrenRexNode.get(0);
+ RexNode lowerBound = childrenRexNode.get(1);
+ RexNode upperBound = childrenRexNode.get(2);
return relBuilder.and(
- relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL, childrenRexNode),
- relBuilder.call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, childrenRexNode));
+ relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL, expr, lowerBound), // expr >= lowerBound
+ relBuilder.call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, expr, upperBound)); // AND expr <= upperBound
}
private RexNode convertCeil(List<Expression> children) {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala
index 8b5dada..f53aa1e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala
@@ -25,7 +25,7 @@ import org.apache.flink.table.expressions.{E => PlannerE, UUID => PlannerUUID}
import org.apache.flink.table.functions._
import org.apache.flink.table.types.logical.LogicalTypeRoot.{CHAR, DECIMAL, SYMBOL, TIMESTAMP_WITHOUT_TIME_ZONE}
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks._
-import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo
+import org.apache.flink.table.types.TypeInfoDataTypeConverter.fromDataTypeToTypeInfo
import _root_.scala.collection.JavaConverters._
@@ -53,14 +53,14 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
assert(children.size == 2)
return Cast(
children.head.accept(this),
- fromDataTypeToLegacyInfo(
+ fromDataTypeToTypeInfo(
children(1).asInstanceOf[TypeLiteralExpression].getOutputDataType))
case REINTERPRET_CAST =>
assert(children.size == 3)
Reinterpret(
children.head.accept(this),
- fromDataTypeToLegacyInfo(
+ fromDataTypeToTypeInfo(
children(1).asInstanceOf[TypeLiteralExpression].getOutputDataType),
getValue[Boolean](children(2).accept(this)))
@@ -749,7 +749,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
}
}
- fromDataTypeToLegacyInfo(literal.getOutputDataType)
+ fromDataTypeToTypeInfo(literal.getOutputDataType)
}
private def getSymbol(symbol: TableSymbol): PlannerSymbol = symbol match {
@@ -786,7 +786,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
override def visit(fieldReference: FieldReferenceExpression): PlannerExpression = {
PlannerResolvedFieldReference(
fieldReference.getName,
- fromDataTypeToLegacyInfo(fieldReference.getOutputDataType))
+ fromDataTypeToTypeInfo(fieldReference.getOutputDataType))
}
override def visit(fieldReference: UnresolvedReferenceExpression)
@@ -834,7 +834,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
private def translateWindowReference(reference: Expression): PlannerExpression = reference match {
case expr : LocalReferenceExpression =>
- WindowReference(expr.getName, Some(fromDataTypeToLegacyInfo(expr.getOutputDataType)))
+ WindowReference(expr.getName, Some(fromDataTypeToTypeInfo(expr.getOutputDataType)))
//just because how the datastream is converted to table
case expr: UnresolvedReferenceExpression =>
UnresolvedFieldReference(expr.getName)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/ReturnTypeInference.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/ReturnTypeInference.scala
new file mode 100644
index 0000000..2a333ad
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/ReturnTypeInference.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.expressions
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.table.api.TableException
+import org.apache.flink.table.calcite.{FlinkTypeFactory, FlinkTypeSystem}
+import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.{fromLogicalTypeToTypeInfo, fromTypeInfoToLogicalType}
+import org.apache.flink.table.types.logical.{DecimalType, LogicalType}
+import org.apache.flink.table.typeutils.{BigDecimalTypeInfo, DecimalTypeInfo, TypeCoercion}
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql.`type`.SqlTypeUtil
+
+import scala.collection.JavaConverters._
+
+object ReturnTypeInference { // infers PlannerExpression result types; decimal rules mirror the SQL operator return types
+
+ private lazy val typeSystem = new FlinkTypeSystem
+ private lazy val typeFactory = new FlinkTypeFactory(typeSystem) // shared factory for deriving RelDataTypes below
+
+ /**
+ * Infer resultType of [[Minus]] expression.
+ * The decimal type inference is consistent with Calcite
+ * [[org.apache.calcite.sql.type.ReturnTypes.NULLABLE_SUM]] which is the return type of
+ * [[org.apache.calcite.sql.fun.SqlStdOperatorTable.MINUS]].
+ *
+ * @param minus minus Expression
+ * @return result type
+ */
+ def inferMinus(minus: Minus): TypeInformation[_] = inferPlusOrMinus(minus)
+
+ /**
+ * Infer resultType of [[Plus]] expression.
+ * The decimal type inference is consistent with Calcite
+ * [[org.apache.calcite.sql.type.ReturnTypes.NULLABLE_SUM]] which is the return type of
+ * [[org.apache.calcite.sql.fun.SqlStdOperatorTable.PLUS]].
+ *
+ * @param plus plus Expression
+ * @return result type
+ */
+ def inferPlus(plus: Plus): TypeInformation[_] = inferPlusOrMinus(plus)
+
+ private def inferPlusOrMinus(op: BinaryArithmetic): TypeInformation[_] = { // shared by Plus and Minus
+ val decimalTypeInference = (
+ leftType: RelDataType,
+ rightType: RelDataType,
+ wideResultType: LogicalType) => {
+ if (SqlTypeUtil.isExactNumeric(leftType) &&
+ SqlTypeUtil.isExactNumeric(rightType) &&
+ (SqlTypeUtil.isDecimal(leftType) || SqlTypeUtil.isDecimal(rightType))) {
+ val lp = leftType.getPrecision
+ val ls = leftType.getScale
+ val rp = rightType.getPrecision
+ val rs = rightType.getScale
+ val scale = Math.max(ls, rs) // s = max(s1, s2)
+ assert(scale <= typeSystem.getMaxNumericScale)
+ var precision = Math.max(lp - ls, rp - rs) + scale + 1 // +1 integral digit for a possible carry
+ precision = Math.min(precision, typeSystem.getMaxNumericPrecision) // cap at the type system's max precision
+ assert(precision > 0)
+ fromLogicalTypeToTypeInfo(wideResultType) match { // keep the decimal TypeInformation flavor of the wider type
+ case _: DecimalTypeInfo => DecimalTypeInfo.of(precision, scale)
+ case _: BigDecimalTypeInfo => BigDecimalTypeInfo.of(precision, scale) // NOTE(review): no default case; any other TypeInformation raises MatchError
+ }
+ } else {
+ val resultType = typeFactory.leastRestrictive( // otherwise fall back to the least restrictive operand type
+ List(leftType, rightType).asJava)
+ fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(resultType))
+ }
+ }
+ inferBinaryArithmetic(op, decimalTypeInference, t => fromLogicalTypeToTypeInfo(t))
+ }
+
+ /**
+ * Infer resultType of [[Mul]] expression.
+ * The decimal type inference is consistent with Calcite
+ * [[org.apache.calcite.sql.type.ReturnTypes.PRODUCT_NULLABLE]] which is the return type of
+ * [[org.apache.calcite.sql.fun.SqlStdOperatorTable.MULTIPLY]].
+ *
+ * @param mul mul Expression
+ * @return result type
+ */
+ def inferMul(mul: Mul): TypeInformation[_] = {
+ val decimalTypeInference = (
+ leftType: RelDataType,
+ rightType: RelDataType) => typeFactory.createDecimalProduct(leftType, rightType)
+ inferDivOrMul(mul, decimalTypeInference)
+ }
+
+ /**
+ * Infer resultType of [[Div]] expression.
+ * The decimal type inference is consistent with
+ * [[org.apache.flink.table.calcite.type.FlinkReturnTypes.FLINK_QUOTIENT_NULLABLE]] which
+ * is the return type of [[org.apache.flink.table.functions.sql.FlinkSqlOperatorTable.DIVIDE]].
+ *
+ * @param div div Expression
+ * @return result type
+ */
+ def inferDiv(div: Div): TypeInformation[_] = {
+ val decimalTypeInference = (
+ leftType: RelDataType,
+ rightType: RelDataType) => typeFactory.createDecimalQuotient(leftType, rightType)
+ inferDivOrMul(div, decimalTypeInference)
+ }
+
+ private def inferDivOrMul(
+ op: BinaryArithmetic,
+ decimalTypeInfer: (RelDataType, RelDataType) => RelDataType
+ ): TypeInformation[_] = {
+ val decimalFunc = (
+ leftType: RelDataType,
+ rightType: RelDataType,
+ _: LogicalType) => {
+ val decimalType = decimalTypeInfer(leftType, rightType)
+ if (decimalType != null) { // a null derivation falls back to leastRestrictive below
+ fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(decimalType))
+ } else {
+ val resultType = typeFactory.leastRestrictive(
+ List(leftType, rightType).asJava)
+ fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(resultType))
+ }
+ }
+ val nonDecimalType = op match { // only Div and Mul are passed in; any other op would raise MatchError
+ case _: Div => (_: LogicalType) => BasicTypeInfo.DOUBLE_TYPE_INFO // non-decimal division always yields DOUBLE
+ case _: Mul => (t: LogicalType) => fromLogicalTypeToTypeInfo(t) // non-decimal product keeps the wider type
+ }
+ inferBinaryArithmetic(op, decimalFunc, nonDecimalType)
+ }
+
+ private def inferBinaryArithmetic(
+ binaryOp: BinaryArithmetic,
+ decimalInfer: (RelDataType, RelDataType, LogicalType) => TypeInformation[_],
+ nonDecimalInfer: LogicalType => TypeInformation[_]
+ ): TypeInformation[_] = { // dispatches on the wider of the two operand types
+ val leftType = fromTypeInfoToLogicalType(binaryOp.left.resultType)
+ val rightType = fromTypeInfoToLogicalType(binaryOp.right.resultType)
+ TypeCoercion.widerTypeOf(leftType, rightType) match {
+ case Some(t: DecimalType) => // decimal rules need the precision/scale of both operands
+ val leftRelDataType = typeFactory.createFieldTypeFromLogicalType(leftType)
+ val rightRelDataType = typeFactory.createFieldTypeFromLogicalType(rightType)
+ decimalInfer(leftRelDataType, rightRelDataType, t)
+ case Some(t) => nonDecimalInfer(t)
+ case None => throw new TableException("This will not happen here!") // NOTE(review): assumes operand validation guarantees a wider type exists - confirm
+ }
+ }
+
+ /**
+ * Infer resultType of [[Round]] expression.
+ * The decimal type inference is consistent with Flink's
+ * [[org.apache.flink.table.calcite.type.FlinkReturnTypes]].ROUND_FUNCTION_NULLABLE
+ *
+ * @param round round Expression
+ * @return result type
+ */
+ def inferRound(round: Round): TypeInformation[_] = {
+ val numType = round.left.resultType
+ numType match {
+ case _: DecimalTypeInfo | _: BigDecimalTypeInfo => // decimal input: result precision/scale depend on the rounding length
+ val lenValue = round.right match { // only an INT literal length is supported here
+ case Literal(v: Int, BasicTypeInfo.INT_TYPE_INFO) => v
+ case _ => throw new TableException("This will not happen here!") // NOTE(review): reachable for a non-literal or non-INT integer length (Round.validateInput only checks isInteger) - confirm
+ }
+ val numLogicalType = fromTypeInfoToLogicalType(numType)
+ val numRelDataType = typeFactory.createFieldTypeFromLogicalType(numLogicalType)
+ val p = numRelDataType.getPrecision
+ val s = numRelDataType.getScale
+ val dt = FlinkTypeSystem.inferRoundType(p, s, lenValue)
+ fromLogicalTypeToTypeInfo(dt)
+ case t => t // non-decimal inputs keep their own type
+ }
+ }
+
+ /**
+ * Infer resultType of [[Floor]] expression.
+ * The decimal type inference is consistent with Calcite
+ * [[org.apache.calcite.sql.type.ReturnTypes]].ARG0_OR_EXACT_NO_SCALE
+ *
+ * @param floor floor Expression
+ * @return result type
+ */
+ def inferFloor(floor: Floor): TypeInformation[_] = getArg0OrExactNoScale(floor)
+
+ /**
+ * Infer resultType of [[Ceil]] expression.
+ * The decimal type inference is consistent with Calcite
+ * [[org.apache.calcite.sql.type.ReturnTypes]].ARG0_OR_EXACT_NO_SCALE
+ *
+ * @param ceil ceil Expression
+ * @return result type
+ */
+ def inferCeil(ceil: Ceil): TypeInformation[_] = getArg0OrExactNoScale(ceil)
+
+ private def getArg0OrExactNoScale(op: UnaryExpression) = { // DECIMAL(p, s) -> DECIMAL(p, 0); other types unchanged
+ val childType = op.child.resultType
+ childType match {
+ case t: DecimalTypeInfo => DecimalTypeInfo.of(t.precision(), 0)
+ case t: BigDecimalTypeInfo => BigDecimalTypeInfo.of(t.precision(), 0)
+ case _ => childType
+ }
+ }
+
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/arithmetic.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/arithmetic.scala
index 726d9ff..20a4ba2 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/arithmetic.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/arithmetic.scala
@@ -17,10 +17,10 @@
*/
package org.apache.flink.table.expressions
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable
import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.{fromLogicalTypeToTypeInfo, fromTypeInfoToLogicalType}
-import org.apache.flink.table.typeutils.{DecimalTypeInfo, TypeCoercion}
+import org.apache.flink.table.typeutils.TypeCoercion
import org.apache.flink.table.typeutils.TypeInfoCheckUtils._
import org.apache.flink.table.validate._
@@ -71,6 +71,10 @@ case class Plus(left: PlannerExpression, right: PlannerExpression) extends Binar
s"but was '$left' : '${left.resultType}' and '$right' : '${right.resultType}'.")
}
}
+
+ override private[flink] def resultType: TypeInformation[_] = {
+ ReturnTypeInference.inferPlus(this)
+ }
}
case class UnaryMinus(child: PlannerExpression) extends UnaryExpression {
@@ -111,6 +115,10 @@ case class Minus(left: PlannerExpression, right: PlannerExpression) extends Bina
s"but was '$left' : '${left.resultType}' and '$right' : '${right.resultType}'.")
}
}
+
+ override private[flink] def resultType: TypeInformation[_] = {
+ ReturnTypeInference.inferMinus(this)
+ }
}
case class Div(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic {
@@ -118,17 +126,20 @@ case class Div(left: PlannerExpression, right: PlannerExpression) extends Binary
private[flink] val sqlOperator = FlinkSqlOperatorTable.DIVIDE
- override private[flink] def resultType: TypeInformation[_] =
- super.resultType match {
- case dt: DecimalTypeInfo => dt
- case _ => BasicTypeInfo.DOUBLE_TYPE_INFO
- }
+ override private[flink] def resultType: TypeInformation[_] = {
+ ReturnTypeInference.inferDiv(this)
+ }
+
}
case class Mul(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic {
override def toString = s"($left * $right)"
private[flink] val sqlOperator = FlinkSqlOperatorTable.MULTIPLY
+
+ override private[flink] def resultType: TypeInformation[_] = {
+ ReturnTypeInference.inferMul(this) // decimal operands follow the Calcite MULTIPLY product rule
+ }
}
case class Mod(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
index 7c9d3fd..c28d2b8 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
@@ -32,7 +32,9 @@ case class Abs(child: PlannerExpression) extends UnaryExpression {
}
case class Ceil(child: PlannerExpression) extends UnaryExpression {
- override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO
+ override private[flink] def resultType: TypeInformation[_] = {
+ ReturnTypeInference.inferCeil(this)
+ }
override private[flink] def validateInput(): ValidationResult =
TypeInfoCheckUtils.assertNumericExpr(child.resultType, "Ceil")
@@ -50,7 +52,9 @@ case class Exp(child: PlannerExpression) extends UnaryExpression with InputTypeS
case class Floor(child: PlannerExpression) extends UnaryExpression {
- override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO
+ override private[flink] def resultType: TypeInformation[_] = {
+ ReturnTypeInference.inferFloor(this)
+ }
override private[flink] def validateInput(): ValidationResult =
TypeInfoCheckUtils.assertNumericExpr(child.resultType, "Floor")
@@ -258,7 +262,9 @@ case class Sign(child: PlannerExpression) extends UnaryExpression {
case class Round(left: PlannerExpression, right: PlannerExpression)
extends BinaryExpression {
- override private[flink] def resultType: TypeInformation[_] = left.resultType
+ override private[flink] def resultType: TypeInformation[_] = {
+ ReturnTypeInference.inferRound(this)
+ }
override private[flink] def validateInput(): ValidationResult = {
if (!TypeInfoCheckUtils.isInteger(right.resultType)) {
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala
index 0fc7e2e..137d277 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala
@@ -28,7 +28,7 @@ import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.fromLogicalType
import org.apache.flink.table.types.logical.{DecimalType, LogicalType}
import org.apache.flink.types.Row
-import org.junit.{Assert, Ignore, Test}
+import org.junit.{Assert, Test}
import java.math.{BigDecimal => JBigDecimal}
@@ -591,7 +591,6 @@ class DecimalITCase extends BatchTestBase {
s1r(null))
}
- @Ignore
@Test
def testAggMinMaxCount(): Unit = {
@@ -862,7 +861,6 @@ class DecimalITCase extends BatchTestBase {
s1r(1L))
}
- @Ignore
@Test
def testGroupBy(): Unit = {
checkQuery1(
@@ -896,7 +894,6 @@ class DecimalITCase extends BatchTestBase {
s1r(d"100.000", null, null))
}
- @Ignore
@Test
def testAggAvgGroupBy(): Unit = {
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/table/DecimalITCase.scala
similarity index 59%
copy from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala
copy to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/table/DecimalITCase.scala
index 0fc7e2e..8bb6054 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/table/DecimalITCase.scala
@@ -16,87 +16,61 @@
* limitations under the License.
*/
-package org.apache.flink.table.runtime.batch.sql
+package org.apache.flink.table.runtime.batch.table
import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.table.api.{DataTypes, ExecutionConfigOptions}
-import org.apache.flink.table.runtime.utils.BatchTestBase
+import org.apache.flink.table.api.{DataTypes, ExecutionConfigOptions, Table}
+import org.apache.flink.table.api.scala._
import org.apache.flink.table.runtime.utils.BatchTestBase.row
+import org.apache.flink.table.runtime.utils.{BatchTableEnvUtil, BatchTestBase}
import org.apache.flink.table.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
import org.apache.flink.table.types.PlannerTypeUtils.isInteroperable
import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.fromLogicalTypeToTypeInfo
import org.apache.flink.table.types.logical.{DecimalType, LogicalType}
import org.apache.flink.types.Row
-import org.junit.{Assert, Ignore, Test}
+import org.junit.{Assert, Test}
import java.math.{BigDecimal => JBigDecimal}
import scala.collection.Seq
/**
- * Conformance test of SQL type Decimal(p,s).
+ * Conformance test of TableApi type Decimal(p,s).
* Served also as documentation of our Decimal behaviors.
*/
class DecimalITCase extends BatchTestBase {
- private case class Coll(colTypes: Seq[LogicalType], rows: Seq[Row])
-
- private var globalTableId = 0
- private def checkQueryX(
- tables: Seq[Coll],
- query: String,
- expected: Coll,
- isSorted: Boolean = false)
- : Unit = {
-
- var tableId = 0
- var queryX = query
- tables.foreach{ table =>
- tableId += 1
- globalTableId += 1
- val tableName = "Table" + tableId
- val newTableName = tableName + "_" + globalTableId
- val rowTypeInfo = new RowTypeInfo(table.colTypes.toArray.map(fromLogicalTypeToTypeInfo): _*)
- val fieldNames = rowTypeInfo.getFieldNames.mkString(",")
- registerCollection(newTableName, table.rows, rowTypeInfo, fieldNames)
- queryX = queryX.replace(tableName, newTableName)
- }
+ private def checkQuery( // builds a source table, applies tableTransfer, then checks result schema and rows
+ sourceColTypes: Seq[LogicalType],
+ sourceRows: Seq[Row],
+ tableTransfer: Table => Table, // Table API transformation under test
+ expectedColTypes: Seq[LogicalType],
+ expectedRows: Seq[Row],
+ isSorted: Boolean = false): Unit = { // when false, rows are compared after sorting their string forms
+ val rowTypeInfo = new RowTypeInfo(sourceColTypes.toArray.map(fromLogicalTypeToTypeInfo): _*)
+ val fieldNames = rowTypeInfo.getFieldNames.mkString(",")
+ val t = BatchTableEnvUtil.fromCollection(tEnv, sourceRows, rowTypeInfo, fieldNames) // source table over sourceRows
// check result schema
- val resultTable = parseQuery(queryX)
- val ts1 = expected.colTypes
+ val resultTable = tableTransfer(t)
val ts2 = resultTable.getSchema.getFieldDataTypes.map(fromDataTypeToLogicalType)
- Assert.assertEquals(ts1.length, ts2.length)
+ Assert.assertEquals(expectedColTypes.length, ts2.length)
- Assert.assertTrue(ts1.zip(ts2).forall {
+ Assert.assertTrue(expectedColTypes.zip(ts2).forall { // column types need only be interoperable, not identical
case (t1, t2) => isInteroperable(t1, t2)
})
def prepareResult(isSorted: Boolean, seq: Seq[Row]) = {
if (!isSorted) seq.map(_.toString).sortBy(s => s) else seq.map(_.toString)
}
+
val resultRows = executeQuery(resultTable)
Assert.assertEquals(
- prepareResult(isSorted, expected.rows),
+ prepareResult(isSorted, expectedRows),
prepareResult(isSorted, resultRows))
}
- private def checkQuery1(
- sourceColTypes: Seq[LogicalType],
- sourceRows: Seq[Row],
- query: String,
- expectedColTypes: Seq[LogicalType],
- expectedRows: Seq[Row],
- isSorted: Boolean = false)
- : Unit = {
- checkQueryX(
- Seq(Coll(sourceColTypes, sourceRows)),
- query,
- Coll(expectedColTypes, expectedRows),
- isSorted)
- }
-
// a Seq of one Row
private def s1r(args: Any*): Seq[Row] = Seq(row(args: _*))
@@ -122,9 +96,13 @@ class DecimalITCase extends BatchTestBase {
private def DECIMAL = (p: Int, s: Int) => new DecimalType(p, s)
private def BOOL = DataTypes.BOOLEAN.getLogicalType
+
private def INT = DataTypes.INT.getLogicalType
+
private def LONG = DataTypes.BIGINT.getLogicalType
+
private def DOUBLE = DataTypes.DOUBLE.getLogicalType
+
private def STRING = DataTypes.STRING.getLogicalType
// d"xxx" => new BigDecimal("xxx")
@@ -145,87 +123,62 @@ class DecimalITCase extends BatchTestBase {
def testDataSource(): Unit = {
// the most basic case
- checkQuery1(
+
+ checkQuery(
Seq(DECIMAL(10, 0), DECIMAL(7, 2)),
s1r(d"123", d"123.45"),
- "select * from Table1",
+ table => table.select('*),
Seq(DECIMAL(10, 0), DECIMAL(7, 2)),
s1r(d"123", d"123.45"))
// data from source are rounded to their declared scale before entering next step
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(7, 2)),
s1r(d"100.004"),
- "select f0, f0+f0 from Table1", // 100.00+100.00
+ table => table.select('f0, 'f0 + 'f0), // 100.00+100.00
Seq(DECIMAL(7, 2), DECIMAL(8, 2)),
- s1r(d"100.00", d"200.00")) // not 200.008=>200.01
+ s1r(d"100.00", d"200.00")) // not 200.008=>200.01
// trailing zeros are padded to the scale
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(7, 2)),
s1r(d"100.1"),
- "select f0, f0+f0 from Table1", // 100.00+100.00
+ table => table.select('f0, 'f0 + 'f0), // 100.10+100.10
Seq(DECIMAL(7, 2), DECIMAL(8, 2)),
s1r(d"100.10", d"200.20"))
// source data is within precision after rounding
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(5, 2)),
s1r(d"100.0040"), // p=7 => rounding => p=5
- "select f0, f0+f0 from Table1",
+ table => table.select('f0, 'f0 + 'f0), // 100.00+100.00
Seq(DECIMAL(5, 2), DECIMAL(6, 2)),
s1r(d"100.00", d"200.00"))
// source data overflows over precision (after rounding)
- checkQuery1(
+ checkQuery( // 123 needs 3 integral digits, DECIMAL(2, 0) allows 2
Seq(DECIMAL(2, 0)),
s1r(d"123"),
- "select * from Table1",
+ table => table.select('*),
Seq(DECIMAL(2, 0)),
s1r(null))
- checkQuery1(
+ checkQuery( // 123.0000 needs 3 integral digits, DECIMAL(4, 2) allows 2
Seq(DECIMAL(4, 2)),
s1r(d"123.0000"),
- "select * from Table1",
+ table => table.select('*),
Seq(DECIMAL(4, 2)),
s1r(null))
}
@Test
- def testLiterals(): Unit = {
-
- checkQuery1(
- Seq(DECIMAL(1,0)),
- s1r(d"0"),
- "select 12, 12.3, 12.34 from Table1",
- Seq(INT, DECIMAL(3, 1), DECIMAL(4, 2)),
- s1r(12, d"12.3", d"12.34"))
-
- checkQuery1(
- Seq(DECIMAL(1,0)),
- s1r(d"0"),
- "select 123456789012345678901234567890.12345678 from Table1",
- Seq(DECIMAL(38, 8)),
- s1r(d"123456789012345678901234567890.12345678"))
-
- expectOverflow(()=>
- checkQuery1(
- Seq(DECIMAL(1,0)),
- s1r(d"0"),
- "select 123456789012345678901234567890.123456789 from Table1",
- Seq(DECIMAL(38, 9)),
- s1r(d"123456789012345678901234567890.123456789")))
- }
-
- @Test
def testUnaryPlusMinus(): Unit = {
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 0), DECIMAL(7, 2)),
s1r(d"123", d"123.45"),
- "select +f0, -f1, -((+f0)-(-f1)) from Table1",
- Seq(DECIMAL(10, 0), DECIMAL(7, 2), DECIMAL(13,2)),
+ table => table.select( + 'f0, - 'f1, - (( + 'f0) - ( - 'f1))), // unary +/- keep the operand's type
+ Seq(DECIMAL(10, 0), DECIMAL(7, 2), DECIMAL(13, 2)), // (10,0) - (7,2): s = 2, p = max(10, 5) + 2 + 1 = 13
s1r(d"123", d"-123.45", d"-246.45"))
}
@@ -235,19 +188,19 @@ class DecimalITCase extends BatchTestBase {
// see calcite ReturnTypes.DECIMAL_SUM
// s = max(s1,s2), p-s = max(p1-s1, p2-s2) + 1
// p then is capped at 38
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2), DECIMAL(10, 4)),
s1r(d"100.12", d"200.1234"),
- "select f0+f1, f0-f1 from Table1",
+ table => table.select('f0 + 'f1, 'f0 - 'f1),
Seq(DECIMAL(13, 4), DECIMAL(13, 4)),
s1r(d"300.2434", d"-100.0034"))
// INT => DECIMAL(10,0)
// approximate + exact => approximate
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2), INT, DOUBLE),
s1r(d"100.00", 200, 3.14),
- "select f0+f1, f1+f0, f0+f2, f2+f0 from Table1",
+ table => table.select('f0 + 'f1, 'f1 + 'f0, 'f0 + 'f2, 'f2 + 'f0),
Seq(DECIMAL(13, 2), DECIMAL(13, 2), DOUBLE, DOUBLE),
s1r(d"300.00", d"300.00", d"103.14", d"103.14"))
@@ -257,31 +210,31 @@ class DecimalITCase extends BatchTestBase {
// (38,10)+(38,28)=>(57,28)=>(38,28)
// T-SQL -- scale may be reduced to keep the integral part. approximation may occur
// (38,10)+(38,28)=>(57,28)=>(38,9)
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(38, 10), DECIMAL(38, 28)),
s1r(d"100.0123456789", d"200.0123456789012345678901234567"),
- "select f0+f1, f0-f1 from Table1",
+ table => table.select('f0 + 'f1, 'f0 - 'f1),
Seq(DECIMAL(38, 28), DECIMAL(38, 28)),
s1r(d"300.0246913578012345678901234567", d"-100.0000000000012345678901234567"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(38, 10), DECIMAL(38, 28)),
s1r(d"1e10", d"0"),
- "select f1+f0, f1-f0 from Table1",
+ table => table.select('f1 + 'f0, 'f1 - 'f0 ),
Seq(DECIMAL(38, 28), DECIMAL(38, 28)), // 10 digits integral part
s1r(null, null))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(38, 0)),
s1r(d"5e37"),
- "select f0+f0 from Table1",
+ table => table.select('f0 + 'f0 ),
Seq(DECIMAL(38, 0)),
s1r(null)) // requires 39 digits
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(38, 0), DECIMAL(38, 0)),
s1r(d"5e37", d"5e37"),
- "select f0+f0-f1 from Table1", // overflows in subexpression
+ table => table.select('f0 + 'f0 -'f1 ), // overflows in subexpression
Seq(DECIMAL(38, 0)),
s1r(null))
}
@@ -293,207 +246,166 @@ class DecimalITCase extends BatchTestBase {
// s = s1+s2, p = p1+p2
// both p&s are capped at 38
// if s>38, result is rounded to s=38, and the integral part can only be zero
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(5, 2), DECIMAL(10, 4)),
s1r(d"1.00", d"2.0000"),
- "select f0*f0, f0*f1 from Table1",
+ table => table.select('f0*'f0, 'f0*'f1 ),
Seq(DECIMAL(10, 4), DECIMAL(15, 6)),
s1r(d"1.0000", d"2.000000"))
// INT => DECIMAL(10,0)
// approximate * exact => approximate
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2), INT, DOUBLE),
s1r(d"1.00", 200, 3.14),
- "select f0*f1, f1*f0, f0*f2, f2*f0 from Table1",
+ table => table.select('f0*'f1, 'f1*'f0, 'f0*'f2, 'f2*'f0 ),
Seq(DECIMAL(20, 2), DECIMAL(20, 2), DOUBLE, DOUBLE),
s1r(d"200.00", d"200.00", 3.14, 3.14))
// precision is capped at 38; scale will not be reduced (unless over 38)
// similar to plus&minus, and calcite behavior is different from T-SQL.
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(30, 6), DECIMAL(30, 10)),
s1r(d"1", d"2"),
- "select f0*f0, f0*f1 from Table1",
+ table => table.select('f0*'f0, 'f0*'f1 ),
Seq(DECIMAL(38, 12), DECIMAL(38, 16)),
s1r(d"1${12}", d"2${16}"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(30, 20)),
s1r(d"0.1"),
- "select f0*f0 from Table1",
- Seq(DECIMAL(38, 38)), // (60,40)=>(38,38)
+ table => table.select('f0*'f0 ),
+ Seq(DECIMAL(38, 38)), // (60,40)=>(38,38)
s1r(d"0.01${38}"))
// scalastyle:off
// we don't have this ridiculous behavior:
- // https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/
+ // https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/
+ // (the URL is kept on one line so it stays usable; scalastyle is disabled around it for that reason)
// scalastyle:on
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(38, 10), DECIMAL(38, 10)),
s1r(d"0.0000006", d"1.0"),
- "select f0*f1 from Table1",
+ table => table.select('f0*'f1 ),
Seq(DECIMAL(38, 20)),
s1r(d"0.0000006${20}"))
// result overflow
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(38, 0)),
s1r(d"1e19"),
- "select f0*f0 from Table1",
+ table => table.select('f0*'f0 ),
Seq(DECIMAL(38, 0)),
s1r(null))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(30, 20)),
s1r(d"1.0"),
- "select f0*f0 from Table1",
- Seq(DECIMAL(38, 38)), // (60,40)=>(38,38), no space for integral part
+ table => table.select('f0*'f0 ),
+ Seq(DECIMAL(38, 38)), // (60,40)=>(38,38), no space for integral part
s1r(null))
}
@Test
def testDivide(): Unit = {
- // the default impl of Calcite apparently borrows from T-SQL, but differs in details.
- // Flink overrides it to follow T-SQL exactly. See FlinkTypeFactory.createDecimalQuotient()
- checkQuery1( // test 1/3 in different scales
+ // the default impl of Calcite apparently borrows from T-SQL, but differs in details.
+ // Flink overrides it to follow T-SQL exactly. See FlinkTypeFactory.createDecimalQuotient()
+ checkQuery( // test 1/3 in different scales
Seq(DECIMAL(20, 2), DECIMAL(2, 1), DECIMAL(4, 3), DECIMAL(20, 10), DECIMAL(20, 16)),
s1r(d"1.00", d"3", d"3", d"3", d"3"),
- "select f0/f1, f0/f2, f0/f3, f0/f4 from Table1",
+ table => table.select('f0/'f1, 'f0/'f2, 'f0/'f3, 'f0/'f4 ),
Seq(DECIMAL(25, 6), DECIMAL(28, 7), DECIMAL(38, 10), DECIMAL(38, 6)),
s1r(d"0.333333", d"0.3333333", d"0.3333333333", d"0.333333"))
// INT => DECIMAL(10,0)
// approximate / exact => approximate
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2), INT, DOUBLE),
s1r(d"1.00", 2, 3.0),
- "select f0/f1, f1/f0, f0/f2, f2/f0 from Table1",
+ table => table.select('f0/'f1, 'f1/'f0, 'f0/'f2, 'f2/'f0 ),
Seq(DECIMAL(21, 13), DECIMAL(23, 11), DOUBLE, DOUBLE),
- s1r(d"0.5${13}", d"2${11}", 1.0/3.0, 3.0/1.0))
+ s1r(d"0.5${13}", d"2${11}", 1.0 / 3.0, 3.0 / 1.0))
// result overflow, because result type integral part is reduced
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(30, 0), DECIMAL(30, 20)),
s1r(d"1e20", d"1e-15"),
- "select f0/f1 from Table1",
+ table => table.select('f0/'f1 ),
Seq(DECIMAL(38, 6)),
s1r(null))
}
@Test
def testMod(): Unit = {
-
- // MOD(Exact1, Exact2) => Exact2
- checkQuery1(
- Seq(DECIMAL(10, 2), DECIMAL(10, 4), INT),
- s1r(d"3.00", d"5.0000", 7),
- "select mod(f0,f1), mod(f1,f0), mod(f0,f2), mod(f2,f0) from Table1",
- Seq(DECIMAL(10, 4), DECIMAL(10, 2), INT, DECIMAL(10, 2)),
- s1r(d"3.0000", d"2.00", 3, d"1.00"))
-
// signs. consistent with Java's % operator.
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(1, 0), DECIMAL(1, 0)),
s1r(d"3", d"5"),
- "select mod(f0,f1), mod(-f0,f1), mod(f0,-f1), mod(-f0,-f1) from Table1",
+ table => table.select('f0 % 'f1, (-'f0) % 'f1,'f0 % (-'f1), (-'f0) % (-'f1)),
Seq(DECIMAL(1, 0), DECIMAL(1, 0), DECIMAL(1, 0), DECIMAL(1, 0)),
- s1r(3%5, (-3)%5, 3%(-5), (-3)%(-5)))
+ s1r(3 % 5, (-3) % 5, 3 % (-5), (-3) % (-5)))
// rounding in case s1>s2. note that SQL2003 requires s1=s2=0.
// (In T-SQL, s2 is expanded to s1, so that there's no rounding.)
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 4), DECIMAL(10, 2)),
s1r(d"3.1234", d"5"),
- "select mod(f0,f1) from Table1",
+ table => table.select('f0 % 'f1),
Seq(DECIMAL(10, 2)),
s1r(d"3.12"))
}
- @Test
- def testDiv(): Unit = {
-
- // see DivCallGen
- checkQuery1(
- Seq(DECIMAL(7, 0), INT),
- s1r(d"7", 2),
- "select div(f0, f1), div(100*f1, f0) from Table1",
- Seq(DECIMAL(7, 0), DECIMAL(10, 0)),
- s1r(3, 200 / 7))
-
- checkQuery1(
- Seq(DECIMAL(10, 1), DECIMAL(10, 3)),
- s1r(d"7.9", d"2.009"),
- "select div(f0, f1), div(100*f1, f0) from Table1",
- Seq(DECIMAL(12, 0), DECIMAL(18, 0)),
- s1r(3, 2009 / 79))
- }
-
@Test // functions that treat Decimal as exact value
def testExactFunctions(): Unit = {
- checkQuery1(
- Seq(DECIMAL(10, 2), DECIMAL(10, 2)),
- s1r(d"3.14", d"2.17"),
- "select if(f0>f1, f0, f1) from Table1",
- Seq(DECIMAL(10, 2)),
- s1r(d"3.14"))
-
- checkQuery1(
- Seq(DECIMAL(10, 2)),
- s1r(d"3.14"),
- "select abs(f0), abs(-f0) from Table1",
- Seq(DECIMAL(10, 2), DECIMAL(10, 2)),
- s1r(d"3.14", d"3.14"))
-
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2)),
s1r(d"3.14"),
- "select floor(f0), ceil(f0) from Table1",
+ table => table.select('f0.floor, 'f0.ceil),
Seq(DECIMAL(10, 0), DECIMAL(10, 0)),
s1r(d"3", d"4"))
// calcite: SIGN(Decimal(p,s))=>Decimal(p,s)
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2)),
s1r(d"3.14"),
- "select sign(f0), sign(-f0), sign(f0-f0) from Table1",
+ table => table.select('f0.sign, (-'f0).sign, ('f0 - 'f0).sign ),
Seq(DECIMAL(10, 2), DECIMAL(10, 2), DECIMAL(11, 2)),
s1r(d"1.00", d"-1.00", d"0.00"))
// ROUND(Decimal(p,s)[,INT])
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 3)),
s1r(d"646.646"),
- "select round(f0), round(f0, 0) from Table1",
+ table => table.select('f0.round(0), 'f0.round(0)),
Seq(DECIMAL(8, 0), DECIMAL(8, 0)),
s1r(d"647", d"647"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 3)),
s1r(d"646.646"),
- "select round(f0,1), round(f0,2), round(f0,3), round(f0,4) from Table1",
+ table => table.select('f0.round(1), 'f0.round(2), 'f0.round(3), 'f0.round(4) ),
Seq(DECIMAL(9, 1), DECIMAL(10, 2), DECIMAL(10, 3), DECIMAL(10, 3)),
s1r(d"646.6", d"646.65", d"646.646", d"646.646"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 3)),
s1r(d"646.646"),
- "select round(f0,-1), round(f0,-2), round(f0,-3), round(f0,-4) from Table1",
+ table => table.select('f0.round(-1), 'f0.round(-2), 'f0.round(-3), 'f0.round(-4) ),
Seq(DECIMAL(8, 0), DECIMAL(8, 0), DECIMAL(8, 0), DECIMAL(8, 0)),
s1r(d"650", d"600", d"1000", d"0"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(4, 2)),
s1r(d"99.99"),
- "select round(f0,1), round(-f0,1), round(f0,-1), round(-f0,-1) from Table1",
+ table => table.select('f0.round(1), (-'f0).round(1), 'f0.round(-1), (-'f0).round(-1) ),
Seq(DECIMAL(4, 1), DECIMAL(4, 1), DECIMAL(3, 0), DECIMAL(3, 0)),
s1r(d"100.0", d"-100.0", d"100", d"-100"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(38, 0)),
s1r(d"1E38".subtract(d"1")),
- "select round(f0,-1) from Table1",
+ table => table.select('f0.round(-1) ),
Seq(DECIMAL(38, 0)),
s1r(null))
}
@@ -503,52 +415,24 @@ class DecimalITCase extends BatchTestBase {
import java.lang.Math._
- checkQuery1(
- Seq(DECIMAL(10, 2)),
- s1r(d"3.14"),
- "select log10(f0), ln(f0), log(f0), log2(f0) from Table1",
- Seq(DOUBLE, DOUBLE, DOUBLE, DOUBLE),
- s1r(log10(3.14), Math.log(3.14), Math.log(3.14), Math.log(3.14)/Math.log(2.0)))
-
- checkQuery1(
- Seq(DECIMAL(10, 2), DOUBLE),
- s1r(d"3.14", 3.14),
- "select log(f0,f0), log(f0,f1), log(f1,f0) from Table1",
- Seq(DOUBLE, DOUBLE, DOUBLE),
- s1r(1.0, 1.0, 1.0))
-
- checkQuery1(
- Seq(DECIMAL(10, 2), DOUBLE),
- s1r(d"3.14", 0.3),
- "select power(f0,f0), power(f0,f1), power(f1,f0), sqrt(f0) from Table1",
- Seq(DOUBLE, DOUBLE, DOUBLE, DOUBLE),
- s1r(pow(3.14, 3.14), pow(3.14, 0.3), pow(0.3, 3.14), pow(3.14, 0.5)))
-
- checkQuery1(
- Seq(DECIMAL(10, 2), DOUBLE),
- s1r(d"3.14", 0.3),
- "select exp(f0), exp(f1) from Table1",
- Seq(DOUBLE, DOUBLE),
- s1r(exp(3.14), exp(0.3)))
-
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2)),
s1r(d"0.12"),
- "select sin(f0), cos(f0), tan(f0), cot(f0) from Table1",
+ table => table.select('f0.sin, 'f0.cos, 'f0.tan, 'f0.cot ),
Seq(DOUBLE, DOUBLE, DOUBLE, DOUBLE),
- s1r(sin(0.12), cos(0.12), tan(0.12), 1.0/tan(0.12)))
+ s1r(sin(0.12), cos(0.12), tan(0.12), 1.0 / tan(0.12)))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2)),
s1r(d"0.12"),
- "select asin(f0), acos(f0), atan(f0) from Table1",
+ table => table.select('f0.asin, 'f0.acos, 'f0.atan ),
Seq(DOUBLE, DOUBLE, DOUBLE),
s1r(asin(0.12), acos(0.12), atan(0.12)))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 2)),
s1r(d"0.12"),
- "select degrees(f0), radians(f0) from Table1",
+ table => table.select('f0.degrees, 'f0.radians),
Seq(DOUBLE, DOUBLE),
s1r(toDegrees(0.12), toRadians(0.12)))
}
@@ -557,17 +441,17 @@ class DecimalITCase extends BatchTestBase {
def testAggSum(): Unit = {
// SUM(Decimal(p,s))=>Decimal(38,s)
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(6, 3)),
(0 until 100).map(_ => row(d"1.000")),
- "select sum(f0) from Table1",
+ table => table.select('f0.sum ),
Seq(DECIMAL(38, 3)),
s1r(d"100.000"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(37, 0)),
(0 until 100).map(_ => row(d"1e36")),
- "select sum(f0) from Table1",
+ table => table.select('f0.sum ),
Seq(DECIMAL(38, 0)),
s1r(null))
}
@@ -576,107 +460,87 @@ class DecimalITCase extends BatchTestBase {
def testAggAvg(): Unit = {
// AVG(Decimal(p,s)) => Decimal(38,s)/Decimal(20,0) => Decimal(38, max(s,6))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(6, 3), DECIMAL(20, 10)),
(0 until 100).map(_ => row(d"100.000", d"1${10}")),
- "select avg(f0), avg(f1) from Table1",
+ table => table.select('f0.avg, 'f1.avg ),
Seq(DECIMAL(38, 6), DECIMAL(38, 10)),
s1r(d"100.000000", d"1${10}"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(37, 0)),
(0 until 100).map(_ => row(d"1e36")),
- "select avg(f0) from Table1",
+ table => table.select('f0.avg),
Seq(DECIMAL(38, 6)),
s1r(null))
}
- @Ignore
@Test
def testAggMinMaxCount(): Unit = {
// MIN/MAX(T) => T
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(6, 3)),
(10 to 90).map(i => row(java.math.BigDecimal.valueOf(i))),
- "select min(f0), max(f0), count(f0) from Table1",
+ table => table.select('f0.min, 'f0.max, 'f0.count ),
Seq(DECIMAL(6, 3), DECIMAL(6, 3), LONG),
s1r(d"10.000", d"90.000", 81L))
}
@Test
- def testCaseWhen(): Unit = {
-
- // result type: SQL2003 $9.23, calcite RelDataTypeFactory.leastRestrictive()
- checkQuery1(
- Seq(DECIMAL(8, 4), DECIMAL(10, 2)),
- s1r(d"0.0001", d"0.01"),
- "select case f0 when 0 then f0 else f1 end from Table1",
- Seq(DECIMAL(12, 4)),
- s1r(d"0.0100"))
-
- checkQuery1(
- Seq(DECIMAL(8, 4), INT),
- s1r(d"0.0001", 1),
- "select case f0 when 0 then f0 else f1 end from Table1",
- Seq(DECIMAL(14, 4)),
- s1r(d"1.0000"))
-
- checkQuery1(
- Seq(DECIMAL(8, 4), DOUBLE),
- s1r(d"0.0001", 3.14),
- "select case f0 when 0 then f1 else f0 end from Table1",
- Seq(DOUBLE),
- s1r(d"0.0001".doubleValue()))
- }
-
- @Test
def testCast(): Unit = {
// String, numeric/Decimal => Decimal
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), INT, DOUBLE, STRING),
s1r(d"3.14", 3, 3.14, "3.14"),
- "select cast(f0 as Decimal(8,4)), cast(f1 as Decimal(8,4)), " +
- "cast(f2 as Decimal(8,4)), cast(f3 as Decimal(8,4)) from Table1",
+ table => table.select('f0.cast(DataTypes.DECIMAL(8,4)),
+ 'f1.cast(DataTypes.DECIMAL(8,4)),
+ 'f2.cast(DataTypes.DECIMAL(8,4)),
+ 'f3.cast(DataTypes.DECIMAL(8,4)) ),
Seq(DECIMAL(8, 4), DECIMAL(8, 4), DECIMAL(8, 4), DECIMAL(8, 4)),
s1r(d"3.1400", d"3.0000", d"3.1400", d"3.1400"))
// round up
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DOUBLE, STRING),
s1r(d"3.15", 3.15, "3.15"),
- "select cast(f0 as Decimal(8,1)), cast(f1 as Decimal(8,1)), " +
- "cast(f2 as Decimal(8,1)) from Table1",
+ table => table.select(
+ 'f0.cast(DataTypes.DECIMAL(8,1)),
+ 'f1.cast(DataTypes.DECIMAL(8,1)),
+ 'f2.cast(DataTypes.DECIMAL(8,1))),
Seq(DECIMAL(8, 1), DECIMAL(8, 1), DECIMAL(8, 1)),
s1r(d"3.2", d"3.2", d"3.2"))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(4, 2)),
s1r(d"13.14"),
- "select cast(f0 as Decimal(3,2)) from Table1",
+ table => table.select('f0.cast(DataTypes.DECIMAL(3,2)) ),
Seq(DECIMAL(3, 2)),
s1r(null))
- checkQuery1(
+ checkQuery(
Seq(STRING),
s1r("13.14"),
- "select cast(f0 as Decimal(3,2)) from Table1",
+ table => table.select('f0.cast(DataTypes.DECIMAL(3,2)) ),
Seq(DECIMAL(3, 2)),
s1r(null))
// Decimal => String, numeric
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(4, 2)),
s1r(d"1.99"),
- "select cast(f0 as VARCHAR(64)), cast(f0 as DOUBLE), cast(f0 as INT) from Table1",
+ table => table.select(
+ 'f0.cast(DataTypes.VARCHAR(64)),
+ 'f0.cast(DataTypes.DOUBLE),
+ 'f0.cast(DataTypes.INT)),
Seq(STRING, DOUBLE, INT),
s1r("1.99", 1.99, 1))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(10, 0), DECIMAL(10, 0)),
s1r(d"-2147483648", d"2147483647"),
- "select cast(f0 as INT), cast(f1 as INT) from Table1",
+ table => table.select('f0.cast(DataTypes.INT), 'f1.cast(DataTypes.INT)),
Seq(INT, INT),
s1r(-2147483648, 2147483647))
}
@@ -687,46 +551,31 @@ class DecimalITCase extends BatchTestBase {
// expressions that test equality.
// =, CASE, NULLIF, IN, IS DISTINCT FROM
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select f0=f1, f0=f2, f0=f3, f1=f0, f2=f0, f3=f0 from Table1",
+ table => table.select('f0==='f1, 'f0==='f2, 'f0==='f3, 'f1==='f0, 'f2==='f0, 'f3==='f0 ),
Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL),
s1r(true, true, true, true, true, true))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select f0 IN(f1), f0 IN(f2), f0 IN(f3), " +
- "f1 IN(f0), f2 IN(f0), f3 IN(f0) from Table1",
+ table => table.select('f0.in('f1), 'f0.in('f2), 'f0.in('f3),
+ 'f1.in('f0), 'f2.in('f0), 'f3.in('f0) ),
Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL),
s1r(true, true, true, true, true, true))
- checkQuery1(
- Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
- s1r(d"1", d"1", 1, 1.0),
- "select " +
- "f0 IS DISTINCT FROM f1, f1 IS DISTINCT FROM f0, " +
- "f0 IS DISTINCT FROM f2, f2 IS DISTINCT FROM f0, " +
- "f0 IS DISTINCT FROM f3, f3 IS DISTINCT FROM f0 from Table1",
- Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL),
- s1r(false, false, false, false, false, false))
-
- checkQuery1(
- Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
- s1r(d"1", d"1", 1, 1.0),
- "select NULLIF(f0,f1), NULLIF(f0,f2), NULLIF(f0,f3)," +
- "NULLIF(f1,f0), NULLIF(f2,f0), NULLIF(f3,f0) from Table1",
- Seq(DECIMAL(8, 2), DECIMAL(8, 2), DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
- s1r(null, null, null, null, null, null))
-
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select " +
- "case f0 when f1 then 1 else 0 end, case f1 when f0 then 1 else 0 end, " +
- "case f0 when f2 then 1 else 0 end, case f2 when f0 then 1 else 0 end, " +
- "case f0 when f3 then 1 else 0 end, case f3 when f0 then 1 else 0 end from Table1",
+ table => table.select(
+ ('f0 === 'f1) ? (1, 0),
+ ('f1 === 'f0) ?(1, 0),
+ ('f0 === 'f2) ? (1, 0),
+ ('f2 === 'f0) ? (1, 0),
+ ('f0 === 'f3) ? (1, 0),
+ ('f3 === 'f0) ? (1, 0)),
Seq(INT, INT, INT, INT, INT, INT),
s1r(1, 1, 1, 1, 1, 1))
}
@@ -734,39 +583,45 @@ class DecimalITCase extends BatchTestBase {
@Test
def testComparison(): Unit = {
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select f0<f1, f0<f2, f0<f3, f1<f0, f2<f0, f3<f0 from Table1",
+ table => table.select('f0<'f1, 'f0<'f2, 'f0<'f3, 'f1<'f0, 'f2<'f0, 'f3<'f0 ),
Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL),
s1r(false, false, false, false, false, false))
// no overflow during type conversion.
// conceptually both operands are promoted to infinite precision before comparison.
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(1, 0), DECIMAL(2, 0), INT, DOUBLE),
s1r(d"1", d"99", 99, 99.0),
- "select f0<f1, f0<f2, f0<f3, f1<f0, f2<f0, f3<f0 from Table1",
+ table => table.select('f0<'f1, 'f0<'f2, 'f0<'f3, 'f1<'f0, 'f2<'f0, 'f3<'f0 ),
Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL),
s1r(true, true, true, false, false, false))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select " +
- "f0 between f1 and 1, f1 between f0 and 1, " +
- "f0 between f2 and 1, f2 between f0 and 1, " +
- "f0 between f3 and 1, f3 between f0 and 1 from Table1",
+ table => table.select(
+ 'f0.between('f1, 1),
+ 'f1.between('f0, 1),
+ 'f0.between('f2, 1),
+ 'f2.between('f0, 1),
+ 'f0.between('f3, 1),
+ 'f3.between('f0, 1)),
Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL),
s1r(true, true, true, true, true, true))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select " +
- "f0 between 0 and f1, f1 between 0 and f0, " +
- "f0 between 0 and f2, f2 between 0 and f0, " +
- "f0 between 0 and f3, f3 between 0 and f0 from Table1",
+ table => table.select(
+ 'f0.between(0, 'f1),
+ 'f1.between(0, 'f0),
+ 'f0.between(0, 'f2),
+ 'f2.between(0, 'f0),
+ 'f0.between(0, 'f3),
+ 'f3.between(0, 'f0)),
Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL),
s1r(true, true, true, true, true, true))
}
@@ -776,10 +631,10 @@ class DecimalITCase extends BatchTestBase {
tEnv.getConfig.getConfiguration.setString(
ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin")
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select count(*) from Table1 A, Table1 B where A.f0=B.f0",
+ table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f0).select(1.count),
Seq(LONG),
s1r(1L))
}
@@ -789,10 +644,10 @@ class DecimalITCase extends BatchTestBase {
tEnv.getConfig.getConfiguration.setString(
ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin")
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select count(*) from Table1 A, Table1 B where A.f0=B.f1",
+ table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f1).select(1.count),
Seq(LONG),
s1r(1L))
}
@@ -802,10 +657,10 @@ class DecimalITCase extends BatchTestBase {
tEnv.getConfig.getConfiguration.setString(
ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin")
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select count(*) from Table1 A, Table1 B where A.f1=B.f0",
+ table => table.as('a, 'b, 'c, 'd).join(table).where('b === 'f0).select(1.count),
Seq(LONG),
s1r(1L))
@@ -816,10 +671,10 @@ class DecimalITCase extends BatchTestBase {
tEnv.getConfig.getConfiguration.setString(
ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin")
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select count(*) from Table1 A, Table1 B where A.f0=B.f2",
+ table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f2).select(1.count),
Seq(LONG),
s1r(1L))
}
@@ -829,10 +684,10 @@ class DecimalITCase extends BatchTestBase {
tEnv.getConfig.getConfiguration.setString(
ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin")
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select count(*) from Table1 A, Table1 B where A.f2=B.f0",
+ table => table.as('a, 'b, 'c, 'd).join(table).where('c === 'f0).select(1.count),
Seq(LONG),
s1r(1L))
}
@@ -842,10 +697,10 @@ class DecimalITCase extends BatchTestBase {
tEnv.getConfig.getConfiguration.setString(
ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin")
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select count(*) from Table1 A, Table1 B where A.f0=B.f3",
+ table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f3).select(1.count),
Seq(LONG),
s1r(1L))
}
@@ -854,21 +709,20 @@ class DecimalITCase extends BatchTestBase {
def testJoin7(): Unit = {
tEnv.getConfig.getConfiguration.setString(
ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin")
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE),
s1r(d"1", d"1", 1, 1.0),
- "select count(*) from Table1 A, Table1 B where A.f3=B.f0",
+ table => table.as('a, 'b, 'c, 'd).join(table).where('d === 'f0).select(1.count),
Seq(LONG),
s1r(1L))
}
- @Ignore
@Test
def testGroupBy(): Unit = {
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2)),
Seq(row(d"1"), row(d"3"), row(d"1.0"), row(d"2")),
- "select count(*) from Table1 A group by f0",
+ table => table.groupBy('f0).select(1.count),
Seq(LONG),
Seq(row(2L), row(1L), row(1L)))
}
@@ -876,10 +730,10 @@ class DecimalITCase extends BatchTestBase {
@Test
def testOrderBy(): Unit = {
env.setParallelism(1) // set sink parallelism to 1
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(8, 2)),
Seq(row(d"1"), row(d"3"), row(d"1.0"), row(d"2")),
- "select f0 from Table1 A order by f0",
+ table => table.select('f0).orderBy('f0),
Seq(DECIMAL(8, 2)),
Seq(row(d"1.00"), row(d"1.00"), row(d"2.00"), row(d"3.00")),
isSorted = true)
@@ -887,31 +741,29 @@ class DecimalITCase extends BatchTestBase {
@Test
def testSimpleNull(): Unit = {
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)),
Seq(row(d"100.000", null, null)),
- "select distinct(f0), f1, f2 from (select t1.f0, t1.f1, t1.f2 from Table1 t1 " +
- "union all (SELECT * FROM Table1)) order by f0",
+ table => table.union(table).select('f0, 'f1, 'f2).orderBy('f0),
Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)),
s1r(d"100.000", null, null))
}
- @Ignore
@Test
def testAggAvgGroupBy(): Unit = {
// null
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)),
(0 until 100).map(_ => row(d"100.000", null, null)),
- "select f0, avg(f1), avg(f2) from Table1 group by f0",
+ table => table.groupBy('f0).select('f0, 'f1.avg, 'f2.avg),
Seq(DECIMAL(6, 3), DECIMAL(38, 6), DECIMAL(38, 10)),
s1r(d"100.000", null, null))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)),
(0 until 100).map(_ => row(d"100.000", d"100.000", d"1${10}")),
- "select f0, avg(f1), avg(f2) from Table1 group by f0",
+ table => table.groupBy('f0).select('f0, 'f1.avg, 'f2.avg),
Seq(DECIMAL(6, 3), DECIMAL(38, 6), DECIMAL(38, 10)),
s1r(d"100.000", d"100.000000", d"1${10}"))
}
@@ -920,17 +772,17 @@ class DecimalITCase extends BatchTestBase {
def testAggMinGroupBy(): Unit = {
// null
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)),
(0 until 100).map(_ => row(d"100.000", null, null)),
- "select f0, min(f1), min(f2) from Table1 group by f0",
+ table => table.groupBy('f0).select('f0, 'f1.min, 'f2.min),
Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)),
s1r(d"100.000", null, null))
- checkQuery1(
+ checkQuery(
Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)),
(0 until 100).map(i => row(d"100.000", new JBigDecimal(100 - i), d"1${10}")),
- "select f0, min(f1), min(f2) from Table1 group by f0",
+ table => table.groupBy('f0).select('f0, 'f1.min, 'f2.min),
Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)),
s1r(d"100.000", d"1.000", d"1${10}"))
}