You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/06/12 02:06:37 UTC

spark git commit: [SQL] Miscellaneous SQL/DF expression changes.

Repository: spark
Updated Branches:
  refs/heads/master 7914c720b -> 337c16d57


[SQL] Miscellaneous SQL/DF expression changes.

SPARK-8201 conditional function: if
SPARK-8205 conditional function: nvl
SPARK-8208 math function: ceiling
SPARK-8210 math function: degrees
SPARK-8211 math function: radians
SPARK-8219 math function: negative
SPARK-8216 math function: rename log -> ln
SPARK-8222 math function: alias power / pow
SPARK-8225 math function: alias sign / signum
SPARK-8228 conditional function: isnull
SPARK-8229 conditional function: isnotnull
SPARK-8250 string function: alias lower/lcase
SPARK-8251 string function: alias upper / ucase

Author: Reynold Xin <rx...@databricks.com>

Closes #6754 from rxin/expressions-misc and squashes the following commits:

35fce15 [Reynold Xin] Removed println.
2647067 [Reynold Xin] Promote to string type.
3c32bbc [Reynold Xin] Fixed if.
de827ac [Reynold Xin] Fixed style
b201cd4 [Reynold Xin] Removed if.
6b21a9b [Reynold Xin] [SQL] Miscellaneous SQL/DF expression changes.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/337c16d5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/337c16d5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/337c16d5

Branch: refs/heads/master
Commit: 337c16d57e40cb4967bf85269baae14745f161db
Parents: 7914c72
Author: Reynold Xin <rx...@databricks.com>
Authored: Thu Jun 11 17:06:21 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jun 11 17:06:21 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    | 20 ++++++--
 .../catalyst/analysis/HiveTypeCoercion.scala    | 30 ++++++++++++
 .../analysis/HiveTypeCoercionSuite.scala        | 13 +++++
 .../ConditionalExpressionSuite.scala            | 43 ++++++++++++++++-
 .../spark/sql/ColumnExpressionSuite.scala       | 16 ++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 29 ++++++-----
 .../apache/spark/sql/MathExpressionsSuite.scala | 51 +++++++++++++++++---
 7 files changed, 175 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/337c16d5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a7816e3..45bcbf7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -84,43 +84,51 @@ object FunctionRegistry {
   type FunctionBuilder = Seq[Expression] => Expression
 
   val expressions: Map[String, FunctionBuilder] = Map(
-    // Non aggregate functions
+    // misc non-aggregate functions
     expression[Abs]("abs"),
     expression[CreateArray]("array"),
     expression[Coalesce]("coalesce"),
     expression[Explode]("explode"),
+    expression[If]("if"),
+    expression[IsNull]("isnull"),
+    expression[IsNotNull]("isnotnull"),
+    expression[Coalesce]("nvl"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
     expression[CreateStruct]("struct"),
     expression[Sqrt]("sqrt"),
 
-    // Math functions
+    // math functions
     expression[Acos]("acos"),
     expression[Asin]("asin"),
     expression[Atan]("atan"),
     expression[Atan2]("atan2"),
     expression[Cbrt]("cbrt"),
     expression[Ceil]("ceil"),
+    expression[Ceil]("ceiling"),
     expression[Cos]("cos"),
     expression[EulerNumber]("e"),
     expression[Exp]("exp"),
     expression[Expm1]("expm1"),
     expression[Floor]("floor"),
     expression[Hypot]("hypot"),
-    expression[Log]("log"),
+    expression[Log]("ln"),
     expression[Log10]("log10"),
     expression[Log1p]("log1p"),
+    expression[UnaryMinus]("negative"),
     expression[Pi]("pi"),
     expression[Log2]("log2"),
     expression[Pow]("pow"),
+    expression[Pow]("power"),
     expression[Rint]("rint"),
+    expression[Signum]("sign"),
     expression[Signum]("signum"),
     expression[Sin]("sin"),
     expression[Sinh]("sinh"),
     expression[Tan]("tan"),
     expression[Tanh]("tanh"),
-    expression[ToDegrees]("todegrees"),
-    expression[ToRadians]("toradians"),
+    expression[ToDegrees]("degrees"),
+    expression[ToRadians]("radians"),
 
     // aggregate functions
     expression[Average]("avg"),
@@ -132,10 +140,12 @@ object FunctionRegistry {
     expression[Sum]("sum"),
 
     // string functions
+    expression[Lower]("lcase"),
     expression[Lower]("lower"),
     expression[StringLength]("length"),
     expression[Substring]("substr"),
     expression[Substring]("substring"),
+    expression[Upper]("ucase"),
     expression[Upper]("upper")
   )
 

http://git-wip-us.apache.org/repos/asf/spark/blob/337c16d5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 737905c..6ed1923 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -58,6 +58,15 @@ object HiveTypeCoercion {
     case _ => None
   }
 
+  /** Similar to [[findTightestCommonTypeOfTwo]], but can promote all the way to StringType. */
+  private def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = {
+    findTightestCommonTypeOfTwo(left, right).orElse((left, right) match {
+      case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
+      case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
+      case _ => None
+    })
+  }
+
   /**
    * Find the tightest common type of a set of types by continuously applying
    * `findTightestCommonTypeOfTwo` on these types.
@@ -91,6 +100,7 @@ trait HiveTypeCoercion {
     StringToIntegralCasts ::
     FunctionArgumentConversion ::
     CaseWhenCoercion ::
+    IfCoercion ::
     Division ::
     PropagateTypes ::
     ExpectedInputConversion ::
@@ -653,6 +663,26 @@ trait HiveTypeCoercion {
   }
 
   /**
+   * Coerces the types of the different branches of an If statement to a common type.
+   */
+  object IfCoercion extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      // Find tightest common type for If, if the true value and false value have different types.
+      case i @ If(pred, left, right) if left.dataType != right.dataType =>
+        findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType =>
+          val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
+          val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
+          i.makeCopy(Array(pred, newLeft, newRight))
+        }.getOrElse(i)  // If there is no applicable conversion, leave expression unchanged.
+
+      // Convert the predicate of If(null literal, _, _) into a null literal of boolean type.
+      // In the optimizer, we should short-circuit this directly into the false value.
+      case i @ If(pred, left, right) if pred.dataType == NullType =>
+        i.makeCopy(Array(Literal.create(null, BooleanType), left, right))
+    }
+  }
+
+  /**
    * Casts types according to the expected input types for Expressions that have the trait
    * `ExpectsInputTypes`.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/337c16d5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 9977f7a..f7b8e21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -134,6 +134,19 @@ class HiveTypeCoercionSuite extends PlanTest {
         :: Nil))
   }
 
+  test("type coercion for If") {
+    val rule = new HiveTypeCoercion { }.IfCoercion
+    ruleTest(rule,
+      If(Literal(true), Literal(1), Literal(1L)),
+      If(Literal(true), Cast(Literal(1), LongType), Literal(1L))
+    )
+
+    ruleTest(rule,
+      If(Literal.create(null, NullType), Literal(1), Literal(1)),
+      If(Literal.create(null, BooleanType), Literal(1), Literal(1))
+    )
+  }
+
   test("type coercion for CaseKeyWhen") {
     val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
     ruleTest(cwc,

http://git-wip-us.apache.org/repos/asf/spark/blob/337c16d5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 152c4e4..372848e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -19,11 +19,52 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{IntegerType, BooleanType}
+import org.apache.spark.sql.types._
 
 
 class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
+  test("if") {
+    val testcases = Seq[(java.lang.Boolean, Integer, Integer, Integer)](
+      (true, 1, 2, 1),
+      (false, 1, 2, 2),
+      (null, 1, 2, 2),
+      (true, null, 2, null),
+      (false, 1, null, null),
+      (null, null, 2, 2),
+      (null, 1, null, null)
+    )
+
+    // dataType must be the Catalyst type of the values that convert produces.
+    def testIf(convert: (Integer => Any), dataType: DataType): Unit = {
+      for ((predicate, trueValue, falseValue, expected) <- testcases) {
+        val trueValueConverted = if (trueValue == null) null else convert(trueValue)
+        val falseValueConverted = if (falseValue == null) null else convert(falseValue)
+        val expectedConverted = if (expected == null) null else convert(expected)
+
+        checkEvaluation(
+          If(Literal.create(predicate, BooleanType),
+            Literal.create(trueValueConverted, dataType),
+            Literal.create(falseValueConverted, dataType)),
+          expectedConverted)
+      }
+    }
+
+    testIf(_ == 1, BooleanType)
+    testIf(_.toShort, ShortType)
+    testIf(identity, IntegerType)
+    testIf(_.toLong, LongType)
+
+    testIf(_.toFloat, FloatType)
+    testIf(_.toDouble, DoubleType)
+    testIf(Decimal(_), DecimalType.Unlimited)
+
+    testIf(identity, DateType)
+    testIf(_.toLong, TimestampType)
+
+    testIf(_.toString, StringType)
+  }
+
   test("case when") {
     val row = create_row(null, false, true, "a", "b", "c")
     val c1 = 'a.boolean.at(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/337c16d5/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 4f5484f..efcdae5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -185,12 +185,20 @@ class ColumnExpressionSuite extends QueryTest {
     checkAnswer(
       nullStrings.toDF.where($"s".isNull),
       nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
+
+    checkAnswer(
+      ctx.sql("select isnull(null), isnull(1)"),
+      Row(true, false))
   }
 
   test("isNotNull") {
     checkAnswer(
       nullStrings.toDF.where($"s".isNotNull),
       nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
+
+    checkAnswer(
+      ctx.sql("select isnotnull(null), isnotnull('a')"),
+      Row(false, true))
   }
 
   test("===") {
@@ -393,6 +401,10 @@ class ColumnExpressionSuite extends QueryTest {
       testData.select(upper(lit(null))),
       (1 to 100).map(n => Row(null))
     )
+
+    checkAnswer(
+      ctx.sql("SELECT upper('aB'), ucase('cDe')"),
+      Row("AB", "CDE"))
   }
 
   test("lower") {
@@ -410,6 +422,10 @@ class ColumnExpressionSuite extends QueryTest {
       testData.select(lower(lit(null))),
       (1 to 100).map(n => Row(null))
     )
+
+    checkAnswer(
+      ctx.sql("SELECT lower('aB'), lcase('cDe')"),
+      Row("ab", "cde"))
   }
 
   test("monotonicallyIncreasingId") {

http://git-wip-us.apache.org/repos/asf/spark/blob/337c16d5/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 659b64c..cfd2386 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -110,7 +110,20 @@ class DataFrameFunctionsSuite extends QueryTest {
       testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
   }
 
-  test("length") {
+  test("if function") {
+    val df = Seq((1, 2)).toDF("a", "b")
+    checkAnswer(
+      df.selectExpr("if(a = 1, 'one', 'not_one')", "if(b = 1, 'one', 'not_one')"),
+      Row("one", "not_one"))
+  }
+
+  test("nvl function") {
+    checkAnswer(
+      ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
+      Row("x", "y", null))
+  }
+
+  test("string length function") {
     checkAnswer(
       nullStrings.select(strlen($"s"), strlen("s")),
       nullStrings.collect().toSeq.map { r =>
@@ -127,18 +140,4 @@ class DataFrameFunctionsSuite extends QueryTest {
         Row(l)
       })
   }
-
-  test("log2 functions test") {
-    val df = Seq((1, 2)).toDF("a", "b")
-    checkAnswer(
-      df.select(log2("b") + log2("a")),
-      Row(1))
-
-    checkAnswer(
-      ctx.sql("SELECT LOG2(8)"),
-      Row(3))
-    checkAnswer(
-      ctx.sql("SELECT LOG2(null)"),
-      Row(null))
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/337c16d5/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 0a38af2..6561c3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{log => logarithm}
 
 
 private object MathExpressionsTestData {
@@ -151,20 +152,31 @@ class MathExpressionsSuite extends QueryTest {
     testOneToOneMathFunction(tanh, math.tanh)
   }
 
-  test("toDeg") {
+  test("toDegrees") {
     testOneToOneMathFunction(toDegrees, math.toDegrees)
+    checkAnswer(
+      ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
+      Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
+    )
   }
 
-  test("toRad") {
+  test("toRadians") {
     testOneToOneMathFunction(toRadians, math.toRadians)
+    checkAnswer(
+      ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
+      Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
+    )
   }
 
   test("cbrt") {
     testOneToOneMathFunction(cbrt, math.cbrt)
   }
 
-  test("ceil") {
+  test("ceil and ceiling") {
     testOneToOneMathFunction(ceil, math.ceil)
+    checkAnswer(
+      ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
+      Row(0.0, 1.0, 2.0))
   }
 
   test("floor") {
@@ -183,12 +195,21 @@ class MathExpressionsSuite extends QueryTest {
     testOneToOneMathFunction(expm1, math.expm1)
   }
 
-  test("signum") {
+  test("signum / sign") {
     testOneToOneMathFunction[Double](signum, math.signum)
+
+    checkAnswer(
+      ctx.sql("SELECT sign(10), signum(-11)"),
+      Row(1, -1))
   }
 
-  test("pow") {
+  test("pow / power") {
     testTwoToOneMathFunction(pow, pow, math.pow)
+
+    checkAnswer(
+      ctx.sql("SELECT pow(1, 2), power(2, 1)"),
+      Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
+    )
   }
 
   test("hypot") {
@@ -199,8 +220,12 @@ class MathExpressionsSuite extends QueryTest {
     testTwoToOneMathFunction(atan2, atan2, math.atan2)
   }
 
-  test("log") {
+  test("log / ln") {
     testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
+    checkAnswer(
+      ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
+      Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
+    )
   }
 
   test("log10") {
@@ -211,4 +236,18 @@ class MathExpressionsSuite extends QueryTest {
     testOneToOneNonNegativeMathFunction(log1p, math.log1p)
   }
 
+  test("log2") {
+    val df = Seq((1, 2)).toDF("a", "b")
+    checkAnswer(
+      df.select(log2("b") + log2("a")),
+      Row(1))
+
+    checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
+  }
+
+  test("negative") {
+    checkAnswer(
+      ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
+      Row(-1, 0, 1))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org