You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/03 05:37:34 UTC

spark git commit: [SPARK-8213][SQL]Add function factorial

Repository: spark
Updated Branches:
  refs/heads/master aa7bbc143 -> 1a7a7d7d5


[SPARK-8213][SQL]Add function factorial

Author: zhichao.li <zh...@intel.com>

Closes #6822 from zhichao-li/factorial and squashes the following commits:

26edf4f [zhichao.li] add factorial


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1a7a7d7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1a7a7d7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1a7a7d7d

Branch: refs/heads/master
Commit: 1a7a7d7d579c5cba104daffbda977915802bf9b9
Parents: aa7bbc1
Author: zhichao.li <zh...@intel.com>
Authored: Thu Jul 2 20:37:31 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jul 2 20:37:31 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../spark/sql/catalyst/expressions/math.scala   | 80 +++++++++++++++++++-
 .../expressions/MathFunctionsSuite.scala        | 15 +++-
 .../scala/org/apache/spark/sql/functions.scala  | 16 ++++
 .../apache/spark/sql/MathExpressionsSuite.scala | 13 +++-
 5 files changed, 122 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1a7a7d7d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e7e4d1c..9163b03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -113,6 +113,7 @@ object FunctionRegistry {
     expression[Exp]("exp"),
     expression[Expm1]("expm1"),
     expression[Floor]("floor"),
+    expression[Factorial]("factorial"),
     expression[Hypot]("hypot"),
     expression[Hex]("hex"),
     expression[Logarithm]("log"),

http://git-wip-us.apache.org/repos/asf/spark/blob/1a7a7d7d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 035980d..701ab99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -21,8 +21,10 @@ import java.lang.{Long => JLong}
 import java.util.Arrays
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StringType}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types.{DataType, DoubleType, LongType, IntegerType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -159,6 +161,82 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP
 
 case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")
 
+object Factorial {
+  // Precomputed n! for n = 0 to 20, indexed by n; 21! would overflow a signed 64-bit Long.
+  private val factorials: Array[Long] = Array[Long](
+    1,
+    1,
+    2,
+    6,
+    24,
+    120,
+    720,
+    5040,
+    40320,
+    362880,
+    3628800,
+    39916800,
+    479001600,
+    6227020800L,
+    87178291200L,
+    1307674368000L,
+    20922789888000L,
+    355687428096000L,
+    6402373705728000L,
+    121645100408832000L,
+    2432902008176640000L
+  )
+
+  /** Table lookup of n!; Long.MaxValue once n is past the table. Negative n is not guarded. */
+  def factorial(n: Int): Long =
+    if (n >= factorials.length) Long.MaxValue else factorials(n)
+}
+
+/**
+ * Computes the factorial of an integer child; yields null outside [0, 20] or for null input.
+ */
+case class Factorial(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[DataType] = Seq(IntegerType)
+
+  override def dataType: DataType = LongType
+
+  override def foldable: Boolean = child.foldable
+
+  // Inputs outside the range [0, 20] evaluate to null, so the result is always nullable.
+  override def nullable: Boolean = true
+
+  override def toString: String = s"factorial($child)"
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      null
+    } else {
+      val n = value.asInstanceOf[Int]
+      if (n < 0 || n > 20) null else Factorial.factorial(n)
+    }
+  }
+
+  // The generated Java mirrors eval: a null input or a value outside [0, 20] yields null.
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val eval = child.gen(ctx)
+    eval.code + s"""
+      boolean ${ev.isNull} = ${eval.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        if (${eval.primitive} > 20 || ${eval.primitive} < 0) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.primitive} =
+            org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${eval.primitive});
+        }
+      }
+    """
+  }
+}
+
 case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
 
 case class Log2(child: Expression)

http://git-wip-us.apache.org/repos/asf/spark/blob/1a7a7d7d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index aa27fe3..8457864 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import com.google.common.math.LongMath
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType}
+import org.apache.spark.sql.types.{DataType, LongType}
+import org.apache.spark.sql.types.{IntegerType, DoubleType}
 
 class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -157,6 +160,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     testUnary(Floor, math.floor)
   }
 
+  test("factorial") {
+    (0 to 20).foreach { value =>
+      checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow)
+    }
+    checkEvaluation(Factorial(Literal.create(null, IntegerType)), null, create_row(null))
+    checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow)
+    checkEvaluation(Factorial(Literal(21)), null, EmptyRow)  // above the supported range
+    checkEvaluation(Factorial(Literal(-1)), null, EmptyRow)  // below the supported range
+  }
+
   test("rint") {
     testUnary(Rint, math.rint)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1a7a7d7d/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4ee1fb8..0d5d49c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1023,6 +1023,22 @@ object functions {
   def expm1(columnName: String): Column = expm1(Column(columnName))
 
   /**
+   * Computes the factorial of the given value; the result is null for inputs outside [0, 20].
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def factorial(e: Column): Column = Factorial(e.expr)
+
+  /**
+   * Computes the factorial of the named column; the result is null for inputs outside [0, 20].
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def factorial(columnName: String): Column = factorial(Column(columnName))
+
+  /**
    * Computes the floor of the given value.
    *
    * @group math_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/1a7a7d7d/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 4c5696d..dc8f994 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.functions.{log => logarithm}
 
-
 private object MathExpressionsTestData {
   case class DoubleData(a: java.lang.Double, b: java.lang.Double)
   case class NullDoubles(a: java.lang.Double)
@@ -183,6 +182,18 @@ class MathExpressionsSuite extends QueryTest {
     testOneToOneMathFunction(floor, math.floor)
   }
 
+  test("factorial") {
+    // Inputs 0..5 in column "a"; column "b" is carried along but never queried.
+    val input = (0 to 5).map(i => (i, i)).toDF("a", "b")
+    // 0! through 5!, one row per input value.
+    val expected = Seq(1, 1, 2, 6, 24, 120).map(Row(_))
+
+    // DataFrame function API.
+    checkAnswer(input.select(factorial('a)), expected)
+    // SQL expression parsed from a string.
+    checkAnswer(input.selectExpr("factorial(a)"), expected)
+  }
+
   test("rint") {
     testOneToOneMathFunction(rint, math.rint)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org