You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2018/10/18 14:09:56 UTC

[GitHub] pnowojski closed pull request #6736: [FLINK-10398][table] Add Tanh math function supported in Table API and SQL

pnowojski closed pull request #6736: [FLINK-10398][table] Add Tanh math function supported in Table API and SQL
URL: https://github.com/apache/flink/pull/6736
 
 
   

This PR was merged from a forked repository. Because GitHub hides the
original diff of a foreign (fork) pull request once it is merged, the
diff is reproduced below for the sake of provenance:

diff --git a/docs/dev/table/functions.md b/docs/dev/table/functions.md
index b6d8a7ec7d9..37171a2079f 100644
--- a/docs/dev/table/functions.md
+++ b/docs/dev/table/functions.md
@@ -1219,6 +1219,18 @@ TAN(numeric)
       </td>
     </tr>
 
+    <tr>
+      <td>
+        {% highlight text %}
+TANH(numeric)
+{% endhighlight %}
+      </td>
+      <td>
+        <p>Returns the hyperbolic tangent of <i>numeric</i>.</p> 
+        <p>The return type is <i>DOUBLE</i>.</p>
+      </td>
+    </tr>
+
     <tr>
       <td>
         {% highlight text %}
@@ -1666,6 +1678,18 @@ NUMERIC.tan()
       </td>
     </tr>
 
+    <tr>
+      <td>
+        {% highlight java %}
+NUMERIC.tanh()
+{% endhighlight %}
+      </td>
+      <td>
+        <p>Returns the hyperbolic tangent of <i>NUMERIC</i>.</p> 
+        <p>The return type is <i>DOUBLE</i>.</p>
+      </td>
+    </tr>
+
     <tr>
       <td>
         {% highlight java %}
@@ -2114,6 +2138,18 @@ NUMERIC.tan()
       </td>
     </tr>
 
+    <tr>
+      <td>
+        {% highlight scala %}
+NUMERIC.tanh()
+{% endhighlight %}
+      </td>
+      <td>
+        <p>Returns the hyperbolic tangent of <i>NUMERIC</i>.</p> 
+        <p>The return type is <i>DOUBLE</i>.</p>
+      </td>
+    </tr>
+
     <tr>
       <td>
         {% highlight scala %}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index e1947c37acf..7c585674b54 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -397,6 +397,11 @@ trait ImplicitExpressionOperations {
     */
   def atan() = Atan(expr)
 
+  /**
+    * Calculates the hyperbolic tangent of a given number.
+    */
+  def tanh() = Tanh(expr)
+
   /**
     * Converts numeric from radians to degrees.
     */
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
index 7781b57825e..1ae6e39e073 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala
@@ -72,6 +72,9 @@ object BuiltInMethods {
   val TAN = Types.lookupMethod(classOf[Math], "tan", classOf[Double])
   val TAN_DEC = Types.lookupMethod(classOf[SqlFunctions], "tan", classOf[JBigDecimal])
 
+  val TANH = Types.lookupMethod(classOf[Math], "tanh", classOf[Double])
+  val TANH_DEC = Types.lookupMethod(classOf[ScalarFunctions], "tanh", classOf[JBigDecimal])
+
   val COT = Types.lookupMethod(classOf[SqlFunctions], "cot", classOf[Double])
   val COT_DEC = Types.lookupMethod(classOf[SqlFunctions], "cot", classOf[JBigDecimal])
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
index 6404813702f..9a6aeb15b3b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
@@ -331,6 +331,18 @@ object FunctionGenerator {
     DOUBLE_TYPE_INFO,
     BuiltInMethods.TAN_DEC)
 
+  addSqlFunctionMethod(
+    TANH,
+    Seq(DOUBLE_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.TANH)
+
+  addSqlFunctionMethod(
+    TANH,
+    Seq(BIG_DEC_TYPE_INFO),
+    DOUBLE_TYPE_INFO,
+    BuiltInMethods.TANH_DEC)
+
   addSqlFunctionMethod(
     COT,
     Seq(DOUBLE_TYPE_INFO),
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
index 6067130cd99..97e7190c0eb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
@@ -217,6 +217,20 @@ case class Tan(child: Expression) extends UnaryExpression {
   }
 }
 
+case class Tanh(child: Expression) extends UnaryExpression {
+
+  override private[flink] def resultType: TypeInformation[_] = DOUBLE_TYPE_INFO
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder) = {
+    relBuilder.call(ScalarSqlFunctions.TANH, child.toRexNode)
+  }
+
+  override private[flink] def validateInput(): ValidationResult =
+    TypeCheckUtils.assertNumericExpr(child.resultType, "Tanh")
+
+  override def toString = s"tanh($child)"
+}
+
 case class Cot(child: Expression) extends UnaryExpression {
   override private[flink] def resultType: TypeInformation[_] = DOUBLE_TYPE_INFO
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/sql/ScalarSqlFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/sql/ScalarSqlFunctions.scala
index 0f594bbcd6c..85f47e8a7a1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/sql/ScalarSqlFunctions.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/sql/ScalarSqlFunctions.scala
@@ -34,6 +34,14 @@ object ScalarSqlFunctions {
     OperandTypes.NILADIC,
     SqlFunctionCategory.NUMERIC)
 
+  val TANH = new SqlFunction(
+    "TANH",
+    SqlKind.OTHER_FUNCTION,
+    ReturnTypes.DOUBLE_NULLABLE,
+    null,
+    OperandTypes.NUMERIC,
+    SqlFunctionCategory.NUMERIC)
+
   val BIN = new SqlFunction(
     "BIN",
     SqlKind.OTHER_FUNCTION,
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/functions/ScalarFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/functions/ScalarFunctions.scala
index 1db4da2f899..91269e702b2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/functions/ScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/functions/ScalarFunctions.scala
@@ -114,6 +114,13 @@ object ScalarFunctions {
     }
   }
 
+  /**
+    * Calculates the hyperbolic tangent of a big decimal number.
+    */
+  def tanh(x: JBigDecimal): Double = {
+    Math.tanh(x.doubleValue())
+  }
+
   /**
     * Returns the logarithm of "x" with base "base".
     */
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 26a20ff41a4..5a39014a6fd 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -232,6 +232,7 @@ object FunctionCatalog {
     "sin" -> classOf[Sin],
     "cos" -> classOf[Cos],
     "tan" -> classOf[Tan],
+    "tanh" -> classOf[Tanh],
     "cot" -> classOf[Cot],
     "asin" -> classOf[Asin],
     "acos" -> classOf[Acos],
@@ -432,6 +433,7 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
     SqlStdOperatorTable.SIN,
     SqlStdOperatorTable.COS,
     SqlStdOperatorTable.TAN,
+    ScalarSqlFunctions.TANH,
     SqlStdOperatorTable.COT,
     SqlStdOperatorTable.ASIN,
     SqlStdOperatorTable.ACOS,
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarFunctionsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarFunctionsTest.scala
index 6422dd1ebdd..02baafd8382 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarFunctionsTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarFunctionsTest.scala
@@ -1444,6 +1444,45 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
       math.tan(-1231.1231231321321321111).toString)
   }
 
+  @Test
+  def testTanh(): Unit = {
+    testAllApis(
+      0.tanh(),
+      "0.tanh()",
+      "TANH(0)",
+      math.tanh(0).toString)
+
+    testAllApis(
+      -1.tanh(),
+      "-1.tanh()",
+      "TANH(-1)",
+      math.tanh(-1).toString)
+
+    testAllApis(
+      'f4.tanh(),
+      "f4.tanh",
+      "TANH(f4)",
+      math.tanh(44L).toString)
+
+    testAllApis(
+      'f6.tanh(),
+      "f6.tanh",
+      "TANH(f6)",
+      math.tanh(4.6D).toString)
+
+    testAllApis(
+      'f7.tanh(),
+      "f7.tanh",
+      "TANH(f7)",
+      math.tanh(3).toString)
+
+    testAllApis(
+      'f22.tanh(),
+      "f22.tanh",
+      "TANH(f22)",
+      math.tanh(2.0).toString)
+  }
+
   @Test
   def testCot(): Unit = {
     testAllApis(
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
index 93285d0d962..88cd78adfd9 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
@@ -111,6 +111,7 @@ class SqlExpressionTest extends ExpressionTestBase {
     testSqlApi("SIN(2.5)", "0.5984721441039564")
     testSqlApi("COS(2.5)", "-0.8011436155469337")
     testSqlApi("TAN(2.5)", "-0.7470222972386603")
+    testSqlApi("TANH(2.5)", "0.9866142981514303")
     testSqlApi("COT(2.5)", "-1.3386481283041514")
     testSqlApi("ASIN(0.5)", "0.5235987755982989")
     testSqlApi("ACOS(0.5)", "1.0471975511965979")


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services