Posted to commits@spark.apache.org by sr...@apache.org on 2016/11/06 14:11:43 UTC

spark git commit: [SPARK-17854][SQL] rand/randn allows null/long as input seed

Repository: spark
Updated Branches:
  refs/heads/master 23ce0d1e9 -> 340f09d10


[SPARK-17854][SQL] rand/randn allows null/long as input seed

## What changes were proposed in this pull request?

This PR proposes that `rand`/`randn` accept `null` as the input seed in Scala/SQL and `LongType` as the input seed in SQL. A `null` seed is treated as `0`.

So, this PR includes the two changes below:
- `null` support

  It seems MySQL also accepts this.

  ``` sql
  mysql> select rand(0);
  +---------------------+
  | rand(0)             |
  +---------------------+
  | 0.15522042769493574 |
  +---------------------+
  1 row in set (0.00 sec)

  mysql> select rand(NULL);
  +---------------------+
  | rand(NULL)          |
  +---------------------+
  | 0.15522042769493574 |
  +---------------------+
  1 row in set (0.00 sec)
  ```

  Hive does as well, according to [HIVE-14694](https://issues.apache.org/jira/browse/HIVE-14694).

  So the code below:

  ``` scala
  spark.range(1).selectExpr("rand(null)").show()
  ```

  prints:

  **Before**

  ```
    Input argument to rand must be an integer literal.;; line 1 pos 0
  org.apache.spark.sql.AnalysisException: Input argument to rand must be an integer literal.;; line 1 pos 0
  at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:465)
  at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:444)
  ```

  **After**

  ```
    +-----------------------+
    |rand(CAST(NULL AS INT))|
    +-----------------------+
    |    0.13385709732307427|
    +-----------------------+
  ```
- `LongType` support in SQL.

  In addition, this makes the functions accept `LongType` consistently across Scala and SQL.

  In more detail, the code below:

  ``` scala
  spark.range(1).select(rand(1), rand(1L)).show()
  spark.range(1).selectExpr("rand(1)", "rand(1L)").show()
  ```

  prints:

  **Before**

  ```
  +------------------+------------------+
  |           rand(1)|           rand(1)|
  +------------------+------------------+
  |0.2630967864682161|0.2630967864682161|
  +------------------+------------------+

  Input argument to rand must be an integer literal.;; line 1 pos 0
  org.apache.spark.sql.AnalysisException: Input argument to rand must be an integer literal.;; line 1 pos 0
  at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:465)
  at
  ```

  **After**

  ```
  +------------------+------------------+
  |           rand(1)|           rand(1)|
  +------------------+------------------+
  |0.2630967864682161|0.2630967864682161|
  +------------------+------------------+

  +------------------+------------------+
  |           rand(1)|           rand(1)|
  +------------------+------------------+
  |0.2630967864682161|0.2630967864682161|
  +------------------+------------------+
  ```
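
As a rough illustration of why a `null` seed ends up as `0`: the seed is read from the literal with `asInstanceOf`, and in Scala unboxing a `null` held as `Any` to `Int` or `Long` yields `0`. The sketch below is self-contained and uses hypothetical stand-ins (`Literal`, `DataType`, `SeedSketch.resolveSeed`) rather than the actual Catalyst classes, and it throws `IllegalArgumentException` where the patch throws `AnalysisException`.

``` scala
// Hypothetical stand-ins for Catalyst's types; these are NOT Spark's classes.
sealed trait DataType
case object IntegerType extends DataType
case object LongType extends DataType
case object StringType extends DataType

// A literal value with its declared type; value may be null.
final case class Literal(value: Any, dataType: DataType)

object SeedSketch {
  // Same shape as the pattern match added to RDG in this patch.
  // Unboxing a null `Any` via asInstanceOf[Int] / asInstanceOf[Long] gives 0,
  // so a typed null literal resolves to the seed 0.
  def resolveSeed(child: Literal): Long = child match {
    case Literal(s, IntegerType) => s.asInstanceOf[Int].toLong
    case Literal(s, LongType)    => s.asInstanceOf[Long]
    case _ => throw new IllegalArgumentException(
      "Input argument must be an integer, long or null literal.")
  }
}
```

With these stand-ins, `SeedSketch.resolveSeed(Literal(null, IntegerType))` and `SeedSketch.resolveSeed(Literal(0L, LongType))` resolve to the same seed, which matches `rand(NULL)` and `rand(0)` returning the same value in the output above.
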
## How was this patch tested?

Unit tests in `DataFrameSuite.scala` and `RandomSuite.scala`.

Author: hyukjinkwon <gu...@gmail.com>

Closes #15432 from HyukjinKwon/SPARK-17854.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/340f09d1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/340f09d1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/340f09d1

Branch: refs/heads/master
Commit: 340f09d100cb669bc6795f085aac6fa05630a076
Parents: 23ce0d1
Author: hyukjinkwon <gu...@gmail.com>
Authored: Sun Nov 6 14:11:37 2016 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Nov 6 14:11:37 2016 +0000

----------------------------------------------------------------------
 .../expressions/randomExpressions.scala         | 50 +++++++-----
 .../sql/catalyst/expressions/RandomSuite.scala  |  6 ++
 .../test/resources/sql-tests/inputs/random.sql  | 17 ++++
 .../resources/sql-tests/results/random.sql.out  | 84 ++++++++++++++++++++
 4 files changed, 135 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/340f09d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index a331a55..1d7a3c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
  *
  * Since this expression is stateful, it cannot be a case object.
  */
-abstract class RDG extends LeafExpression with Nondeterministic {
-
-  protected def seed: Long
-
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
   /**
    * Record ID within each partition. By being transient, the Random Number Generator is
    * reset every time we serialize and deserialize and initialize it.
@@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic {
     rng = new XORShiftRandom(seed + partitionIndex)
   }
 
+  @transient protected lazy val seed: Long = child match {
+    case Literal(s, IntegerType) => s.asInstanceOf[Int]
+    case Literal(s, LongType) => s.asInstanceOf[Long]
+    case _ => throw new AnalysisException(
+      s"Input argument to $prettyName must be an integer, long or null literal.")
+  }
+
   override def nullable: Boolean = false
 
   override def dataType: DataType = DoubleType
 
-  // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
-  override def sql: String = s"$prettyName($seed)"
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
 }
 
 /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
@@ -64,17 +66,15 @@ abstract class RDG extends LeafExpression with Nondeterministic {
        0.9629742951434543
       > SELECT _FUNC_(0);
        0.8446490682263027
+      > SELECT _FUNC_(null);
+       0.8446490682263027
   """)
 // scalastyle:on line.size.limit
-case class Rand(seed: Long) extends RDG {
-  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
+case class Rand(child: Expression) extends RDG {
 
-  def this() = this(Utils.random.nextLong())
+  def this() = this(Literal(Utils.random.nextLong(), LongType))
 
-  def this(seed: Expression) = this(seed match {
-    case IntegerLiteral(s) => s
-    case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
-  })
+  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
@@ -87,6 +87,10 @@ case class Rand(seed: Long) extends RDG {
   }
 }
 
+object Rand {
+  def apply(seed: Long): Rand = Rand(Literal(seed, LongType))
+}
+
 /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
@@ -97,17 +101,15 @@ case class Rand(seed: Long) extends RDG {
        -0.3254147983080288
       > SELECT _FUNC_(0);
        1.1164209726833079
+      > SELECT _FUNC_(null);
+       1.1164209726833079
   """)
 // scalastyle:on line.size.limit
-case class Randn(seed: Long) extends RDG {
-  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
+case class Randn(child: Expression) extends RDG {
 
-  def this() = this(Utils.random.nextLong())
+  def this() = this(Literal(Utils.random.nextLong(), LongType))
 
-  def this(seed: Expression) = this(seed match {
-    case IntegerLiteral(s) => s
-    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
-  })
+  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
@@ -119,3 +121,7 @@ case class Randn(seed: Long) extends RDG {
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
   }
 }
+
+object Randn {
+  def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/340f09d1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index b7a0d44..752c9d5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -20,12 +20,18 @@ package org.apache.spark.sql.catalyst.expressions
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.{IntegerType, LongType}
 
 class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("random") {
     checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
     checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)
+
+    checkDoubleEvaluation(
+      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
+    checkDoubleEvaluation(
+      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
   }
 
   test("SPARK-9127 codegen with long seed") {

http://git-wip-us.apache.org/repos/asf/spark/blob/340f09d1/sql/core/src/test/resources/sql-tests/inputs/random.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/random.sql b/sql/core/src/test/resources/sql-tests/inputs/random.sql
new file mode 100644
index 0000000..a1aae7b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/random.sql
@@ -0,0 +1,17 @@
+-- rand with the seed 0
+SELECT rand(0);
+SELECT rand(cast(3 / 7 AS int));
+SELECT rand(NULL);
+SELECT rand(cast(NULL AS int));
+
+-- rand unsupported data type
+SELECT rand(1.0);
+
+-- randn with the seed 0
+SELECT randn(0L);
+SELECT randn(cast(3 / 7 AS long));
+SELECT randn(NULL);
+SELECT randn(cast(NULL AS long));
+
+-- randn unsupported data type
+SELECT rand('1')

http://git-wip-us.apache.org/repos/asf/spark/blob/340f09d1/sql/core/src/test/resources/sql-tests/results/random.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out
new file mode 100644
index 0000000..bca6732
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out
@@ -0,0 +1,84 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 10
+
+
+-- !query 0
+SELECT rand(0)
+-- !query 0 schema
+struct<rand(0):double>
+-- !query 0 output
+0.8446490682263027
+
+
+-- !query 1
+SELECT rand(cast(3 / 7 AS int))
+-- !query 1 schema
+struct<rand(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS INT)):double>
+-- !query 1 output
+0.8446490682263027
+
+
+-- !query 2
+SELECT rand(NULL)
+-- !query 2 schema
+struct<rand(CAST(NULL AS INT)):double>
+-- !query 2 output
+0.8446490682263027
+
+
+-- !query 3
+SELECT rand(cast(NULL AS int))
+-- !query 3 schema
+struct<rand(CAST(NULL AS INT)):double>
+-- !query 3 output
+0.8446490682263027
+
+
+-- !query 4
+SELECT rand(1.0)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'rand(1.0BD)' due to data type mismatch: argument 1 requires (int or bigint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7
+
+
+-- !query 5
+SELECT randn(0L)
+-- !query 5 schema
+struct<randn(0):double>
+-- !query 5 output
+1.1164209726833079
+
+
+-- !query 6
+SELECT randn(cast(3 / 7 AS long))
+-- !query 6 schema
+struct<randn(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS BIGINT)):double>
+-- !query 6 output
+1.1164209726833079
+
+
+-- !query 7
+SELECT randn(NULL)
+-- !query 7 schema
+struct<randn(CAST(NULL AS INT)):double>
+-- !query 7 output
+1.1164209726833079
+
+
+-- !query 8
+SELECT randn(cast(NULL AS long))
+-- !query 8 schema
+struct<randn(CAST(NULL AS BIGINT)):double>
+-- !query 8 output
+1.1164209726833079
+
+
+-- !query 9
+SELECT rand('1')
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'rand('1')' due to data type mismatch: argument 1 requires (int or bigint) type, however, ''1'' is of string type.; line 1 pos 7

