Posted to commits@spark.apache.org by da...@apache.org on 2015/07/13 09:14:37 UTC

spark git commit: [SPARK-8203] [SPARK-8204] [SQL] conditional function: least/greatest

Repository: spark
Updated Branches:
  refs/heads/master 20b474335 -> 92540d22e


[SPARK-8203] [SPARK-8204] [SQL] conditional function: least/greatest

cc: chenghao-intel, zhichao-li, qiansl127

Author: Daoyuan Wang <da...@intel.com>

Closes #6851 from adrian-wang/udflg and squashes the following commits:

0f1bff2 [Daoyuan Wang] address comments from davis
7a6bdbb [Daoyuan Wang] add '.' for hex()
c1f6824 [Daoyuan Wang] add codegen, test for all types
ec625b0 [Daoyuan Wang] conditional function: least/greatest
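
For readers skimming the diff: least/greatest take two or more arguments of
one common type and return the smallest/largest value, skipping nulls. A
minimal DataFrame usage sketch (the sqlContext handle and the Long literal
are illustration-only assumptions, not part of this commit):

    import org.apache.spark.sql.functions.{col, lit, least, greatest}

    // a = 1, 2, 3 (LongType); the literal must also be a Long, since the
    // new expressions require all arguments to share one data type.
    val df = sqlContext.range(1, 4).toDF("a")
    df.select(least(col("a"), lit(2L)), greatest(col("a"), lit(2L))).show()
    // least(a, 2):    1, 2, 2
    // greatest(a, 2): 2, 2, 3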


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92540d22
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92540d22
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92540d22

Branch: refs/heads/master
Commit: 92540d22e45f9300f413f520a1770e9f36cfa833
Parents: 20b4743
Author: Daoyuan Wang <da...@intel.com>
Authored: Mon Jul 13 00:14:32 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Jul 13 00:14:32 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   2 +
 .../sql/catalyst/expressions/conditionals.scala | 103 ++++++++++++++++++-
 .../ConditionalExpressionSuite.scala            |  81 +++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  |  60 ++++++++++-
 .../spark/sql/DataFrameFunctionsSuite.scala     |  22 ++++
 5 files changed, 263 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/92540d22/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f62d79f..ed69c42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -76,9 +76,11 @@ object FunctionRegistry {
     expression[CreateArray]("array"),
     expression[Coalesce]("coalesce"),
     expression[Explode]("explode"),
+    expression[Greatest]("greatest"),
     expression[If]("if"),
     expression[IsNull]("isnull"),
     expression[IsNotNull]("isnotnull"),
+    expression[Least]("least"),
     expression[Coalesce]("nvl"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
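
Registering the two expressions here is what makes the names resolvable from
SQL text. A quick check of what this enables (assuming a SQLContext named
sqlContext is in scope):

    sqlContext.sql("SELECT least(3, 1, 2), greatest(3, 1, 2)").show()
    // -> 1, 3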

http://git-wip-us.apache.org/repos/asf/spark/blob/92540d22/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index 395e84f..e6a705f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types.{NullType, BooleanType, DataType}
 
 
 case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
@@ -312,3 +313,103 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
     }.mkString
   }
 }
+
+case class Least(children: Expression*) extends Expression {
+  require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length)
+
+  override def nullable: Boolean = children.forall(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  private lazy val ordering = TypeUtils.getOrdering(dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"The expressions should all have the same type," +
+          s" got LEAST (${children.map(_.dataType)}).")
+    } else {
+      TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+    }
+  }
+
+  override def dataType: DataType = children.head.dataType
+
+  override def eval(input: InternalRow): Any = {
+    children.foldLeft[Any](null)((r, c) => {
+      val evalc = c.eval(input)
+      if (evalc != null) {
+        if (r == null || ordering.lt(evalc, r)) evalc else r
+      } else {
+        r
+      }
+    })
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val evalChildren = children.map(_.gen(ctx))
+    def updateEval(i: Int): String =
+      s"""
+        if (!${evalChildren(i).isNull} && (${ev.isNull} ||
+          ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) {
+          ${ev.isNull} = false;
+          ${ev.primitive} = ${evalChildren(i).primitive};
+        }
+      """
+    s"""
+      ${evalChildren.map(_.code).mkString("\n")}
+      boolean ${ev.isNull} = true;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${(0 until children.length).map(updateEval).mkString("\n")}
+    """
+  }
+}
+
+case class Greatest(children: Expression*) extends Expression {
+  require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length)
+
+  override def nullable: Boolean = children.forall(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  private lazy val ordering = TypeUtils.getOrdering(dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"The expressions should all have the same type," +
+          s" got GREATEST (${children.map(_.dataType)}).")
+    } else {
+      TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+    }
+  }
+
+  override def dataType: DataType = children.head.dataType
+
+  override def eval(input: InternalRow): Any = {
+    children.foldLeft[Any](null)((r, c) => {
+      val evalc = c.eval(input)
+      if (evalc != null) {
+        if (r == null || ordering.gt(evalc, r)) evalc else r
+      } else {
+        r
+      }
+    })
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val evalChildren = children.map(_.gen(ctx))
+    def updateEval(i: Int): String =
+      s"""
+        if (!${evalChildren(i).isNull} && (${ev.isNull} ||
+          ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) {
+          ${ev.isNull} = false;
+          ${ev.primitive} = ${evalChildren(i).primitive};
+        }
+      """
+    s"""
+      ${evalChildren.map(_.code).mkString("\n")}
+      boolean ${ev.isNull} = true;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${(0 until children.length).map(updateEval).mkString("\n")}
+    """
+  }
+}
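
Two implementation notes on the classes above. Null handling: eval folds over
the children and skips null inputs, so the result is null only when every
argument evaluates to null (the Literal(null) cases in the tests below).
Codegen: genCode unrolls the same fold into straight-line Java, one guarded
update per child. For a two-argument integer least, the emitted code has
roughly this shape (a hand-expansion of the template for illustration, not
actual compiler output; ctx.genComp may render the comparison differently):

    // sketch of the Java generated for least(a, b) over ints
    boolean isNull = true;
    int value = 0;  // ctx.defaultValue(IntegerType)
    if (!aIsNull && (isNull || (a < value ? -1 : a > value ? 1 : 0) < 0)) {
      isNull = false;
      value = a;
    }
    if (!bIsNull && (isNull || (b < value ? -1 : b > value ? 1 : 0) < 0)) {
      isNull = false;
      value = b;
    }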

http://git-wip-us.apache.org/repos/asf/spark/blob/92540d22/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 372848e..aaf40cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Timestamp, Date}
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 
@@ -134,4 +137,82 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
   }
 
+  test("function least") {
+    val row = create_row(1, 2, "a", "b", "c")
+    val c1 = 'a.int.at(0)
+    val c2 = 'a.int.at(1)
+    val c3 = 'a.string.at(2)
+    val c4 = 'a.string.at(3)
+    val c5 = 'a.string.at(4)
+    checkEvaluation(Least(c4, c3, c5), "a", row)
+    checkEvaluation(Least(c1, c2), 1, row)
+    checkEvaluation(Least(c1, c2, Literal(-1)), -1, row)
+    checkEvaluation(Least(c4, c5, c3, c3, Literal("a")), "a", row)
+
+    checkEvaluation(Least(Literal(null), Literal(null)), null, InternalRow.empty)
+    checkEvaluation(Least(Literal(-1.0), Literal(2.5)), -1.0, InternalRow.empty)
+    checkEvaluation(Least(Literal(-1), Literal(2)), -1, InternalRow.empty)
+    checkEvaluation(
+      Least(Literal((-1.0).toFloat), Literal(2.5.toFloat)), (-1.0).toFloat, InternalRow.empty)
+    checkEvaluation(
+      Least(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MinValue, InternalRow.empty)
+    checkEvaluation(Least(Literal(1.toByte), Literal(2.toByte)), 1.toByte, InternalRow.empty)
+    checkEvaluation(
+      Least(Literal(1.toShort), Literal(2.toByte.toShort)), 1.toShort, InternalRow.empty)
+    checkEvaluation(Least(Literal("abc"), Literal("aaaa")), "aaaa", InternalRow.empty)
+    checkEvaluation(Least(Literal(true), Literal(false)), false, InternalRow.empty)
+    checkEvaluation(
+      Least(
+        Literal(BigDecimal("1234567890987654321123456")),
+        Literal(BigDecimal("1234567890987654321123458"))),
+      BigDecimal("1234567890987654321123456"), InternalRow.empty)
+    checkEvaluation(
+      Least(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))),
+      Date.valueOf("2015-01-01"), InternalRow.empty)
+    checkEvaluation(
+      Least(
+        Literal(Timestamp.valueOf("2015-07-01 08:00:00")),
+        Literal(Timestamp.valueOf("2015-07-01 10:00:00"))),
+      Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty)
+  }
+
+  test("function greatest") {
+    val row = create_row(1, 2, "a", "b", "c")
+    val c1 = 'a.int.at(0)
+    val c2 = 'a.int.at(1)
+    val c3 = 'a.string.at(2)
+    val c4 = 'a.string.at(3)
+    val c5 = 'a.string.at(4)
+    checkEvaluation(Greatest(c4, c5, c3), "c", row)
+    checkEvaluation(Greatest(c2, c1), 2, row)
+    checkEvaluation(Greatest(c1, c2, Literal(2)), 2, row)
+    checkEvaluation(Greatest(c4, c5, c3, Literal("ccc")), "ccc", row)
+
+    checkEvaluation(Greatest(Literal(null), Literal(null)), null, InternalRow.empty)
+    checkEvaluation(Greatest(Literal(-1.0), Literal(2.5)), 2.5, InternalRow.empty)
+    checkEvaluation(Greatest(Literal(-1), Literal(2)), 2, InternalRow.empty)
+    checkEvaluation(
+      Greatest(Literal((-1.0).toFloat), Literal(2.5.toFloat)), 2.5.toFloat, InternalRow.empty)
+    checkEvaluation(
+      Greatest(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MaxValue, InternalRow.empty)
+    checkEvaluation(Greatest(Literal(1.toByte), Literal(2.toByte)), 2.toByte, InternalRow.empty)
+    checkEvaluation(
+      Greatest(Literal(1.toShort), Literal(2.toByte.toShort)), 2.toShort, InternalRow.empty)
+    checkEvaluation(Greatest(Literal("abc"), Literal("aaaa")), "abc", InternalRow.empty)
+    checkEvaluation(Greatest(Literal(true), Literal(false)), true, InternalRow.empty)
+    checkEvaluation(
+      Greatest(
+        Literal(BigDecimal("1234567890987654321123456")),
+        Literal(BigDecimal("1234567890987654321123458"))),
+      BigDecimal("1234567890987654321123458"), InternalRow.empty)
+    checkEvaluation(
+      Greatest(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))),
+      Date.valueOf("2015-07-01"), InternalRow.empty)
+    checkEvaluation(
+      Greatest(
+        Literal(Timestamp.valueOf("2015-07-01 08:00:00")),
+        Literal(Timestamp.valueOf("2015-07-01 10:00:00"))),
+      Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty)
+  }
+
 }
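
Note that checkEvaluation (from ExpressionEvalHelper) runs each expression
through both the interpreted eval path and the generated-code path, so the
assertions above exercise the genCode templates as well. Extending coverage is
a one-liner in the same style (a hypothetical example, not part of this
commit):

    checkEvaluation(
      Greatest(Literal("b"), Literal("a"), Literal("c")), "c", InternalRow.empty)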

http://git-wip-us.apache.org/repos/asf/spark/blob/92540d22/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 08bf37a..ffa52f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -599,7 +599,7 @@ object functions {
   /**
    * Creates a new row for each element in the given array or map column.
    */
-   def explode(e: Column): Column = Explode(e.expr)
+  def explode(e: Column): Column = Explode(e.expr)
 
   /**
    * Converts a string expression to lower case.
@@ -1073,15 +1073,41 @@ object functions {
   def floor(columnName: String): Column = floor(Column(columnName))
 
   /**
-   * Computes hex value of the given column
+   * Returns the greatest value of the list of values.
    *
-   * @group math_funcs
+   * @group normal_funcs
    * @since 1.5.0
    */
+  @scala.annotation.varargs
+  def greatest(exprs: Column*): Column = if (exprs.length < 2) {
+    sys.error("GREATEST takes at least 2 parameters")
+  } else {
+    Greatest(exprs.map(_.expr): _*)
+  }
+
+  /**
+   * Returns the greatest value of the list of column names.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
+    sys.error("GREATEST takes at least 2 parameters")
+  } else {
+    greatest((columnName +: columnNames).map(Column.apply): _*)
+  }
+
+  /**
+   * Computes hex value of the given column.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
   def hex(column: Column): Column = Hex(column.expr)
 
   /**
-   * Computes hex value of the given input
+   * Computes hex value of the given input.
    *
    * @group math_funcs
    * @since 1.5.0
@@ -1172,6 +1198,32 @@ object functions {
   def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName))
 
   /**
+   * Returns the least value of the list of values.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def least(exprs: Column*): Column = if (exprs.length < 2) {
+    sys.error("LEAST takes at least 2 parameters")
+  } else {
+    Least(exprs.map(_.expr): _*)
+  }
+
+  /**
+   * Returns the least value of the list of column names.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
+    sys.error("LEAST takes at least 2 parameters")
+  } else {
+    least((columnName +: columnNames).map(Column.apply): _*)
+  }
+
+  /**
    * Computes the natural logarithm of the given value.
    *
    * @group math_funcs
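
The Column and column-name overloads both enforce the two-argument minimum
eagerly via sys.error at call time, rather than leaving it to analysis. A
usage sketch (df and its integer columns a and b are assumptions; per
checkInputDataTypes above, all arguments must share one data type):

    import org.apache.spark.sql.functions._

    df.select(greatest("a", "b"))                  // column-name variant
    df.select(least(col("a"), col("b"), lit(0)))   // Column variant with a literal
    // least(col("a")) would throw: "LEAST takes at least 2 parameters"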

http://git-wip-us.apache.org/repos/asf/spark/blob/92540d22/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 1732803..6cebec9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -381,4 +381,26 @@ class DataFrameFunctionsSuite extends QueryTest {
       df.selectExpr("split(a, '[1-9]+')"),
       Row(Seq("aa", "bb", "cc")))
   }
+
+  test("conditional function: least") {
+    checkAnswer(
+      testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),
+      Row(-1)
+    )
+    checkAnswer(
+      ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
+      Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
+    )
+  }
+
+  test("conditional function: greatest") {
+    checkAnswer(
+      testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1),
+      Row(3)
+    )
+    checkAnswer(
+      ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
+      Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
+    )
+  }
 }
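
For readers without the fixtures handy: testData2 is the two-column fixture
from the sql/core tests, containing (assuming the fixture is unchanged) the
rows (a, b) = (1,1), (1,2), (2,1), (2,2), (3,1), (3,2), which is what the
expected sequences encode:

    // least(a, 2)    per row: 1, 1, 2, 2, 2, 2  -> ordered: 1, 1, 2, 2, 2, 2
    // greatest(a, 2) per row: 2, 2, 2, 2, 3, 3  -> ordered: 2, 2, 2, 2, 3, 3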


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org