You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/06/18 22:01:00 UTC

spark git commit: [SPARK-8363][SQL] Move sqrt to math and extend UnaryMathExpression

Repository: spark
Updated Branches:
  refs/heads/master ddc5baf17 -> 31641128b


[SPARK-8363][SQL] Move sqrt to math and extend UnaryMathExpression

JIRA: https://issues.apache.org/jira/browse/SPARK-8363

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #6823 from viirya/move_sqrt and squashes the following commits:

8977e11 [Liang-Chi Hsieh] Remove unnecessary old tests.
d23e79e [Liang-Chi Hsieh] Explicitly indicate sqrt value sequence.
699f48b [Liang-Chi Hsieh] Use correct @since tag.
8dff6d1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into move_sqrt
bc2ed77 [Liang-Chi Hsieh] Remove/move arithmetic expression test and expression type checking test. Remove unnecessary Sqrt type rule.
d38492f [Liang-Chi Hsieh] Now sqrt accepts boolean because type casting is handled by HiveTypeCoercion.
297cc90 [Liang-Chi Hsieh] Sqrt only accepts double input.
ef4a21a [Liang-Chi Hsieh] Move sqrt to math.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/31641128
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/31641128
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/31641128

Branch: refs/heads/master
Commit: 31641128b34d6f2aa7cb67324c24dd8b3ed84689
Parents: ddc5baf
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Jun 18 13:00:31 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jun 18 13:00:31 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    |  1 -
 .../sql/catalyst/expressions/arithmetic.scala   | 32 --------------------
 .../spark/sql/catalyst/expressions/math.scala   |  2 ++
 .../expressions/ArithmeticExpressionSuite.scala | 15 ---------
 .../ExpressionTypeCheckingSuite.scala           |  2 --
 .../expressions/MathFunctionsSuite.scala        | 10 ++++++
 .../scala/org/apache/spark/sql/functions.scala  | 10 +++++-
 .../apache/spark/sql/MathExpressionsSuite.scala | 10 ++++++
 8 files changed, 31 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 189451d..8012b22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -307,7 +307,6 @@ trait HiveTypeCoercion {
 
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
-      case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 167e460..ace8427 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -67,38 +67,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
   protected override def evalInternal(evalE: Any) = evalE
 }
 
-case class Sqrt(child: Expression) extends UnaryArithmetic {
-  override def dataType: DataType = DoubleType
-  override def nullable: Boolean = true
-  override def toString: String = s"SQRT($child)"
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function sqrt")
-
-  private lazy val numeric = TypeUtils.getNumeric(child.dataType)
-
-  protected override def evalInternal(evalE: Any) = {
-    val value = numeric.toDouble(evalE)
-    if (value < 0) null
-    else math.sqrt(value)
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val eval = child.gen(ctx)
-    eval.code + s"""
-      boolean ${ev.isNull} = ${eval.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        if (${eval.primitive} < 0.0) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
-        }
-      }
-    """
-  }
-}
-
 /**
  * A function that get the absolute value of the numeric value.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 67cb0b5..3b83c6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -193,6 +193,8 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
 
 case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
 
+case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
+
 case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
 
 case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")

http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 3f48432..4bbbbe6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -142,19 +142,4 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
     checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
   }
-
-  test("SQRT") {
-    val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
-    val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
-    val rowSequence = inputSequence.map(l => create_row(l.toDouble))
-    val d = 'a.double.at(0)
-
-    for ((row, expected) <- rowSequence zip expectedResults) {
-      checkEvaluation(Sqrt(d), expected, row)
-    }
-
-    checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
-    checkEvaluation(Sqrt(-1), null, EmptyRow)
-    checkEvaluation(Sqrt(-1.5), null, EmptyRow)
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
index dcb3635..49b1119 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
@@ -54,8 +54,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
   test("check types for unary arithmetic") {
     assertError(UnaryMinus('stringField), "operator - accepts numeric type")
-    assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt
-    assertError(Sqrt('booleanField), "function sqrt accepts numeric type")
     assertError(Abs('stringField), "function abs accepts numeric type")
     assertError(BitwiseNot('stringField), "operator ~ accepts integral type")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 0050ad3..21e9b92 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types.DoubleType
 
@@ -191,6 +192,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
   }
 
+  test("sqrt") {
+    testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1))
+    testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true)
+
+    checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
+    checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow)
+    checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow)
+  }
+
   test("pow") {
     testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
     testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)

http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index dff0932..d8a91be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -707,12 +707,20 @@ object functions {
   /**
    * Computes the square root of the specified float value.
    *
-   * @group normal_funcs
+   * @group math_funcs
    * @since 1.3.0
    */
   def sqrt(e: Column): Column = Sqrt(e.expr)
 
   /**
+   * Computes the square root of the column with the given name.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def sqrt(colName: String): Column = sqrt(Column(colName))
+
+  /**
    * Creates a new struct column. The input column must be a column in a [[DataFrame]], or
    * a derived column expression that is named (i.e. aliased).
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 7c9c121..2768d7d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -270,6 +270,16 @@ class MathExpressionsSuite extends QueryTest {
     checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
   }
 
+  test("sqrt") {
+    val df = Seq((1, 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(sqrt("a"), sqrt("b")),
+      Row(1.0, 2.0))
+
+    checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
+    checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
+  }
+
   test("negative") {
     checkAnswer(
       ctx.sql("SELECT negative(1), negative(0), negative(-1)"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org