You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2019/01/30 08:00:32 UTC

[flink] branch master updated: [FLINK-11296][table] Support truncate in TableAPI and SQL

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new e154375  [FLINK-11296][table] Support truncate in TableAPI and SQL
e154375 is described below

commit e15437529c86f1ec9d20b287b93ddfc7bcc02666
Author: XuQianJin-Stars <x1...@163.com>
AuthorDate: Thu Jan 10 18:06:59 2019 +0800

    [FLINK-11296][table] Support truncate in TableAPI and SQL
    
    This closes #7450
---
 docs/dev/table/functions.md                        | 26 +++++++
 .../flink/table/api/scala/expressionDsl.scala      | 14 ++++
 .../flink/table/codegen/ExpressionReducer.scala    |  6 +-
 .../flink/table/codegen/calls/BuiltInMethods.scala | 18 +++++
 .../table/codegen/calls/FunctionGenerator.scala    | 48 ++++++++++++
 .../flink/table/expressions/mathExpressions.scala  | 33 ++++++++
 .../flink/table/validate/FunctionCatalog.scala     |  2 +
 .../table/expressions/ScalarFunctionsTest.scala    | 88 ++++++++++++++++++++++
 .../table/expressions/SqlExpressionTest.scala      |  2 +
 .../expressions/utils/ScalarTypesTestBase.scala    |  6 +-
 .../validation/ScalarFunctionsValidationTest.scala | 37 +++++++++
 11 files changed, 277 insertions(+), 3 deletions(-)

diff --git a/docs/dev/table/functions.md b/docs/dev/table/functions.md
index f41139d..d12bb74 100644
--- a/docs/dev/table/functions.md
+++ b/docs/dev/table/functions.md
@@ -1454,6 +1454,19 @@ HEX(string)
         <p>E.g. a numeric 20 leads to "14", a numeric 100 leads to "64", a string "hello,world" leads to "68656C6C6F2C776F726C64".</p>
       </td>
     </tr>
+        
+    <tr>
+      <td>
+        {% highlight text %}
+TRUNCATE(numeric1, integer2)
+{% endhighlight %}
+      </td>
+      <td>
+        <p>Returns <i>numeric1</i> truncated to <i>integer2</i> decimal places. Returns NULL if <i>numeric1</i> or <i>integer2</i> is NULL. If <i>integer2</i> is 0, the result has no decimal point or fractional part. <i>integer2</i> can be negative to cause <i>integer2</i> digits left of the decimal point of the value to become zero. This function can also be called with only <i>numeric1</i>, omitting <i>integer2</i>. If <i>integer2</i> is not set, the function truncates a [...]
+        <p>E.g. <code>truncate(42.345, 2)</code> returns 42.34, and <code>truncate(42.345)</code> returns 42.</p>
+      </td>
+    </tr>
+        
   </tbody>
 </table>
 </div>
@@ -1926,6 +1939,19 @@ STRING.hex()
       <p>E.g. a numeric 20 leads to "14", a numeric 100 leads to "64", a string "hello,world" leads to "68656C6C6F2C776F726C64".</p>
     </td>
    </tr>
+ 
+       <tr>
+         <td>
+           {% highlight text %}
+numeric1.truncate(INTEGER2)
+   {% endhighlight %}
+         </td>
+         <td>
+           <p>Returns <i>numeric1</i> truncated to <i>integer2</i> decimal places. Returns NULL if <i>numeric1</i> or <i>integer2</i> is NULL. If <i>integer2</i> is 0, the result has no decimal point or fractional part. <i>integer2</i> can be negative to cause <i>integer2</i> digits left of the decimal point of the value to become zero. This function can also be called with only <i>numeric1</i>, omitting <i>integer2</i>. If <i>integer2</i> is not set, the function truncate [...]
+           <p>E.g. <code>42.324.truncate(2)</code> returns 42.32, and <code>42.324.truncate()</code> returns 42.0.</p>
+         </td>
+       </tr>
+   
   </tbody>
 </table>
 </div>
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index 390960c..54f2514 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -442,6 +442,20 @@ trait ImplicitExpressionOperations {
     */
   def hex() = Hex(expr)
 
+  /**
+    * Returns this numeric expression truncated to n decimal places.
+    * If n is 0, the result has no decimal point or fractional part.
+    * n can be negative to cause n digits left of the decimal point of the value to become zero.
+    * Returns null if this expression or n is null.
+    * E.g. 42.345.truncate(2) leads to 42.34.
+    *
+    * @param n number of decimal places to keep; must be an integer expression
+    */
+  def truncate(n: Expression): Truncate = Truncate(expr, n)
+
+  /**
+    * Returns this numeric expression truncated to 0 decimal places.
+    * Returns null if this expression is null.
+    * E.g. 42.345.truncate() leads to 42.0.
+    */
+  def truncate(): Truncate = Truncate(expr)
+
   // String operations
 
   /**
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
index 2b50bb9..dfed70b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
@@ -127,7 +127,11 @@ class ExpressionReducer(config: TableConfig)
           val reducedValue = reduced.getField(reducedIdx)
           // RexBuilder handle double literal incorrectly, convert it into BigDecimal manually
           val value = if (unreduced.getType.getSqlTypeName == SqlTypeName.DOUBLE) {
-            new java.math.BigDecimal(reducedValue.asInstanceOf[Number].doubleValue())
+            if (reducedValue == null) {
+              reducedValue
+            } else {
+              new java.math.BigDecimal(reducedValue.asInstanceOf[Number].doubleValue())
+            }
           } else {
             reducedValue
           }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
index 8abe55d..4384b03 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
@@ -176,4 +176,22 @@ object BuiltInMethods {
     "repeat",
     classOf[String],
     classOf[Int])
+
+  // TRUNCATE is backed by the overloads of SqlFunctions.struncate. The *_ONE handles take
+  // only the value; the others additionally take the number of decimal places as an Int.
+  private def struncateMethod(argTypes: Class[_]*) =
+    Types.lookupMethod(classOf[SqlFunctions], "struncate", argTypes: _*)
+
+  val TRUNCATE_DOUBLE_ONE = struncateMethod(classOf[Double])
+  val TRUNCATE_INT_ONE = struncateMethod(classOf[Int])
+  val TRUNCATE_LONG_ONE = struncateMethod(classOf[Long])
+  val TRUNCATE_DEC_ONE = struncateMethod(classOf[JBigDecimal])
+
+  val TRUNCATE_DOUBLE = struncateMethod(classOf[Double], classOf[Int])
+  val TRUNCATE_INT = struncateMethod(classOf[Int], classOf[Int])
+  val TRUNCATE_LONG = struncateMethod(classOf[Long], classOf[Int])
+  val TRUNCATE_DEC = struncateMethod(classOf[JBigDecimal], classOf[Int])
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
index c707040..c834925 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
@@ -552,6 +552,54 @@ object FunctionGenerator {
     STRING_TYPE_INFO,
     BuiltInMethods.HEX_STRING)
 
+  // TRUNCATE(value): single-argument overloads, each mapped to the matching
+  // BuiltInMethods.TRUNCATE_*_ONE handle. The result type equals the argument type.
+  // NOTE(review): no FLOAT/SHORT/BYTE overloads are registered here; confirm such inputs are
+  // widened to one of these types before code generation.
+  addSqlFunctionMethod(
+    TRUNCATE,
+    Seq(LONG_TYPE_INFO),
+    LONG_TYPE_INFO,
+    BuiltInMethods.TRUNCATE_LONG_ONE)
+
+  addSqlFunctionMethod(
+    TRUNCATE,
+    Seq(INT_TYPE_INFO),
+    INT_TYPE_INFO,
+    BuiltInMethods.TRUNCATE_INT_ONE)
+
+  addSqlFunctionMethod(
+    TRUNCATE,
+    Seq(BIG_DEC_TYPE_INFO),
+    BIG_DEC_TYPE_INFO,
+    BuiltInMethods.TRUNCATE_DEC_ONE)
+
+  addSqlFunctionMethod(
+    TRUNCATE,
+    Seq(DOUBLE_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.TRUNCATE_DOUBLE_ONE)
+
+  // TRUNCATE(value, scale): the scale operand is always INT; the result type equals the
+  // type of the first operand.
+  addSqlFunctionMethod(
+    TRUNCATE,
+    Seq(LONG_TYPE_INFO, INT_TYPE_INFO),
+    LONG_TYPE_INFO,
+    BuiltInMethods.TRUNCATE_LONG)
+
+  addSqlFunctionMethod(
+    TRUNCATE,
+    Seq(INT_TYPE_INFO, INT_TYPE_INFO),
+    INT_TYPE_INFO,
+    BuiltInMethods.TRUNCATE_INT)
+
+  addSqlFunctionMethod(
+    TRUNCATE,
+    Seq(BIG_DEC_TYPE_INFO, INT_TYPE_INFO),
+    BIG_DEC_TYPE_INFO,
+    BuiltInMethods.TRUNCATE_DEC)
+
+  addSqlFunctionMethod(
+    TRUNCATE,
+    Seq(DOUBLE_TYPE_INFO, INT_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.TRUNCATE_DOUBLE)
+
   // ----------------------------------------------------------------------------------------------
   // Temporal functions
   // ----------------------------------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
index 05539de..da214f9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
@@ -493,3 +493,36 @@ case class UUID() extends LeafExpression {
     relBuilder.call(ScalarSqlFunctions.UUID)
   }
 }
+
+/**
+  * Truncates a numeric expression to a given number of decimal places.
+  *
+  * @param base numeric expression to truncate; its type is also the result type
+  * @param num  number of decimal places to keep (must be an integer expression), or null for
+  *             the single-argument form
+  */
+case class Truncate(base: Expression, num: Expression) extends Expression with InputTypeSpec {
+  /** Single-argument form `truncate(base)`. */
+  def this(base: Expression) = this(base, null)
+
+  override private[flink] def resultType: TypeInformation[_] = base.resultType
+
+  // The scale is optional, so it only becomes a child when it was supplied.
+  override private[flink] def children: Seq[Expression] =
+    if (num == null) Seq(base) else Seq(base, num)
+
+  override private[flink] def expectedTypes: Seq[TypeInformation[_]] =
+    if (num == null) Seq(base.resultType) else Seq(base.resultType, INT_TYPE_INFO)
+
+  override def toString: String = s"truncate(${children.mkString(",")})"
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    relBuilder.call(SqlStdOperatorTable.TRUNCATE, children.map(_.toRexNode))
+  }
+
+  override private[flink] def validateInput(): ValidationResult = {
+    // Check the optional scale first and return its failure as the result, so a non-integer
+    // scale is reported instead of falling through to the numeric check on `base`.
+    if (num != null && !TypeCheckUtils.isInteger(num.resultType)) {
+      ValidationFailure(s"truncate num requires int, get " +
+        s"$num : ${num.resultType}")
+    } else {
+      TypeCheckUtils.assertNumericExpr(base.resultType, s"truncate base :$base")
+    }
+  }
+}
+
+object Truncate {
+  /** Creates the single-argument form `truncate(base)`. */
+  def apply(base: Expression): Truncate = Truncate(base, null)
+}
+
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 18cb806..658d94c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -253,6 +253,7 @@ object FunctionCatalog {
     "randInteger" -> classOf[RandInteger],
     "bin" -> classOf[Bin],
     "hex" -> classOf[Hex],
+    "truncate" -> classOf[Truncate],
 
     // temporal functions
     "extract" -> classOf[Extract],
@@ -478,6 +479,7 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
     ScalarSqlFunctions.RTRIM,
     ScalarSqlFunctions.REPEAT,
     ScalarSqlFunctions.REGEXP_REPLACE,
+    SqlStdOperatorTable.TRUNCATE,
 
     // MATCH_RECOGNIZE
     SqlStdOperatorTable.FIRST,
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarFunctionsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarFunctionsTest.scala
index 23bd2ae..799f636 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarFunctionsTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarFunctionsTest.scala
@@ -1973,6 +1973,94 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
     )
   }
 
+  /** Checks TRUNCATE through the Scala Table API, string expressions and SQL. */
+  @Test
+  def testTruncate(): Unit = {
+    testAllApis(
+      'f29.truncate('f30),
+      "f29.truncate(f30)",
+      "truncate(f29, f30)",
+      "0.4")
+
+    testAllApis(
+      'f31.truncate('f7),
+      "f31.truncate(f7)",
+      "truncate(f31, f7)",
+      "-0.123")
+
+    testAllApis(
+      'f4.truncate('f32),
+      "f4.truncate(f32)",
+      "truncate(f4, f32)",
+      "40")
+
+    testAllApis(
+      'f28.cast(Types.DOUBLE).truncate(1),
+      "f28.cast(DOUBLE).truncate(1)",
+      "truncate(cast(f28 as DOUBLE), 1)",
+      "0.4")
+
+    testAllApis(
+      'f31.cast(Types.DECIMAL).truncate(2),
+      "f31.cast(DECIMAL).truncate(2)",
+      "truncate(cast(f31 as decimal), 2)",
+      "-0.12")
+
+    testAllApis(
+      'f36.cast(Types.DECIMAL).truncate(),
+      "f36.cast(DECIMAL).truncate()",
+      "truncate(42.324)",
+      "42")
+
+    testAllApis(
+      'f5.cast(Types.FLOAT).truncate(),
+      "f5.cast(FLOAT).truncate()",
+      "truncate(cast(f5 as float))",
+      "4.0")
+
+    testAllApis(
+      42.truncate(-1),
+      "42.truncate(-1)",
+      "truncate(42, -1)",
+      "40")
+
+    testAllApis(
+      42.truncate(-3),
+      "42.truncate(-3)",
+      "truncate(42, -3)",
+      "0")
+
+    // A NULL value or a NULL scale yields NULL (two-argument form)
+    testAllApis(
+      'f33.cast(Types.INT).truncate(1),
+      "f33.cast(INT).truncate(1)",
+      "truncate(cast(null as integer), 1)",
+      "null")
+
+    testAllApis(
+      43.21.truncate('f33.cast(Types.INT)),
+      "43.21.truncate(f33.cast(INT))",
+      "truncate(43.21, cast(null as integer))",
+      "null")
+
+    testAllApis(
+      'f33.cast(Types.DOUBLE).truncate(1),
+      "f33.cast(DOUBLE).truncate(1)",
+      "truncate(cast(null as double), 1)",
+      "null")
+
+    // A NULL value yields NULL (single-argument form, same arity in all three APIs)
+    testAllApis(
+      'f33.cast(Types.INT).truncate(),
+      "f33.cast(INT).truncate()",
+      "truncate(cast(null as integer))",
+      "null")
+
+    testAllApis(
+      'f33.cast(Types.DOUBLE).truncate(),
+      "f33.cast(DOUBLE).truncate()",
+      "truncate(cast(null as double))",
+      "null")
+  }
+
   // ----------------------------------------------------------------------------------------------
   // Temporal functions
   // ----------------------------------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
index 7f897d0..4e69de9 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
@@ -126,6 +126,8 @@ class SqlExpressionTest extends ExpressionTestBase {
     testSqlApi("PI", "3.141592653589793")
     testSqlApi("E()", "2.718281828459045")
     testSqlApi("BIN(12)", "1100")
+    testSqlApi("truncate(42.345)", "42")
+    testSqlApi("truncate(cast(42.345 as decimal(2, 3)), 2)", "42.34")
   }
 
   @Test
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala
index 6ad59b1..2543657 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala
@@ -28,7 +28,7 @@ import org.apache.flink.types.Row
 class ScalarTypesTestBase extends ExpressionTestBase {
 
   def testData: Row = {
-    val testData = new Row(36)
+    val testData = new Row(37)
     testData.setField(0, "This is a test String.")
     testData.setField(1, true)
     testData.setField(2, 42.toByte)
@@ -65,6 +65,7 @@ class ScalarTypesTestBase extends ExpressionTestBase {
     testData.setField(33, null)
     testData.setField(34, 256)
     testData.setField(35, "aGVsbG8gd29ybGQ=")
+    testData.setField(36, BigDecimal("42.345").bigDecimal)
     testData
   }
 
@@ -105,6 +106,7 @@ class ScalarTypesTestBase extends ExpressionTestBase {
       Types.INT,
       Types.STRING,
       Types.INT,
-      Types.STRING).asInstanceOf[TypeInformation[Any]]
+      Types.STRING,
+      Types.DECIMAL).asInstanceOf[TypeInformation[Any]]
   }
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala
index 8a5691d..2d9186b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala
@@ -21,6 +21,7 @@ package org.apache.flink.table.expressions.validation
 import org.apache.calcite.avatica.util.TimeUnit
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.api.{SqlParserException, ValidationException}
+import org.apache.flink.table.codegen.CodeGenException
 import org.apache.flink.table.expressions.TimePointUnit
 import org.apache.flink.table.expressions.utils.ScalarTypesTestBase
 import org.junit.Test
@@ -64,6 +65,42 @@ class ScalarFunctionsValidationTest extends ScalarTypesTestBase {
     testSqlApi("BIN(f16)", "101010") // Date type
   }
 
+  // Each invalid call lives in its own test method: with `expected`, a test ends at the first
+  // exception, so any later calls in the same method would never run.
+
+  @Test(expected = classOf[ValidationException])
+  def testInvalidTruncate1(): Unit = {
+    // All arguments are string type
+    testSqlApi(
+      "TRUNCATE('abc', 'def')",
+      "FAIL")
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testInvalidTruncateStringScale(): Unit = {
+    // The second argument is of type String
+    testSqlApi(
+      "TRUNCATE(f12, f0)",
+      "FAIL")
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testInvalidTruncateFloatScale(): Unit = {
+    // The second argument is of type Float
+    testSqlApi(
+      "TRUNCATE(f12,f12)",
+      "FAIL")
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testInvalidTruncateDoubleScale(): Unit = {
+    // The second argument is of type Double
+    testSqlApi(
+      "TRUNCATE(f12, cast(f28 as DOUBLE))",
+      "FAIL")
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testInvalidTruncateDecimalScale(): Unit = {
+    // The second argument is of type BigDecimal
+    testSqlApi(
+      "TRUNCATE(f12,f15)",
+      "FAIL")
+  }
+
+  /** A lone string argument is expected to fail during code generation. */
+  @Test(expected = classOf[CodeGenException])
+  def testInvalidTruncate2(): Unit =
+    testSqlApi("TRUNCATE('abc')", "FAIL")
+
   // ----------------------------------------------------------------------------------------------
   // String functions
   // ----------------------------------------------------------------------------------------------