Posted to commits@spark.apache.org by sr...@apache.org on 2016/11/06 14:11:43 UTC
spark git commit: [SPARK-17854][SQL] rand/randn allows null/long as input seed
Repository: spark
Updated Branches:
refs/heads/master 23ce0d1e9 -> 340f09d10
[SPARK-17854][SQL] rand/randn allows null/long as input seed
## What changes were proposed in this pull request?
This PR proposes that `rand`/`randn` accept `null` as the input seed in Scala/SQL and `LongType` as the input seed in SQL. A `null` seed is treated as `0`.
So, this PR includes both changes below:
- `null` support
It seems MySQL also accepts this.
``` sql
mysql> select rand(0);
+---------------------+
| rand(0) |
+---------------------+
| 0.15522042769493574 |
+---------------------+
1 row in set (0.00 sec)
mysql> select rand(NULL);
+---------------------+
| rand(NULL) |
+---------------------+
| 0.15522042769493574 |
+---------------------+
1 row in set (0.00 sec)
```
Hive accepts it as well, according to [HIVE-14694](https://issues.apache.org/jira/browse/HIVE-14694).
So the code below:
``` scala
spark.range(1).selectExpr("rand(null)").show()
```
prints:
**Before**
```
Input argument to rand must be an integer literal.;; line 1 pos 0
org.apache.spark.sql.AnalysisException: Input argument to rand must be an integer literal.;; line 1 pos 0
at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:465)
at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:444)
```
**After**
```
+-----------------------+
|rand(CAST(NULL AS INT))|
+-----------------------+
| 0.13385709732307427|
+-----------------------+
```
- `LongType` support in SQL.
In addition, this change makes the functions accept `LongType` consistently across Scala and SQL.
In more detail, the code below:
``` scala
spark.range(1).select(rand(1), rand(1L)).show()
spark.range(1).selectExpr("rand(1)", "rand(1L)").show()
```
prints:
**Before**
```
+------------------+------------------+
| rand(1)| rand(1)|
+------------------+------------------+
|0.2630967864682161|0.2630967864682161|
+------------------+------------------+
Input argument to rand must be an integer literal.;; line 1 pos 0
org.apache.spark.sql.AnalysisException: Input argument to rand must be an integer literal.;; line 1 pos 0
at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:465)
at
```
**After**
```
+------------------+------------------+
| rand(1)| rand(1)|
+------------------+------------------+
|0.2630967864682161|0.2630967864682161|
+------------------+------------------+
+------------------+------------------+
| rand(1)| rand(1)|
+------------------+------------------+
|0.2630967864682161|0.2630967864682161|
+------------------+------------------+
```
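The seed handling described above can be sketched without Spark as follows. This is only an illustration under stated assumptions: `SeedArg`, `NullSeed`, `IntSeed`, `LongSeed` and `SeedSketch` are hypothetical names that do not exist in Spark, and `scala.util.Random` stands in for Spark's `XORShiftRandom`, so the generated values will not match the outputs shown in this message.

```scala
import scala.util.Random

// Hypothetical stand-ins for a literal seed argument; these are not Spark classes.
sealed trait SeedArg
case object NullSeed extends SeedArg
final case class IntSeed(value: Int) extends SeedArg
final case class LongSeed(value: Long) extends SeedArg

object SeedSketch {
  // Widen every accepted argument to Long; a null seed falls back to 0.
  def resolveSeed(arg: SeedArg): Long = arg match {
    case NullSeed    => 0L
    case IntSeed(v)  => v.toLong
    case LongSeed(v) => v
  }

  // Assumption: scala.util.Random replaces XORShiftRandom here, so only the
  // seed handling is illustrated, not Spark's actual random values.
  def firstDraw(arg: SeedArg): Double =
    new Random(resolveSeed(arg)).nextDouble()
}
```

With this sketch, `SeedSketch.firstDraw(NullSeed)` equals `SeedSketch.firstDraw(IntSeed(0))`, and `IntSeed(1)` and `LongSeed(1L)` give the same first draw, which mirrors the `rand(null)` vs. `rand(0)` and `rand(1)` vs. `rand(1L)` behavior described above.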
## How was this patch tested?
Unit tests in `DataFrameSuite.scala` and `RandomSuite.scala`.
Author: hyukjinkwon <gu...@gmail.com>
Closes #15432 from HyukjinKwon/SPARK-17854.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/340f09d1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/340f09d1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/340f09d1
Branch: refs/heads/master
Commit: 340f09d100cb669bc6795f085aac6fa05630a076
Parents: 23ce0d1
Author: hyukjinkwon <gu...@gmail.com>
Authored: Sun Nov 6 14:11:37 2016 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Nov 6 14:11:37 2016 +0000
----------------------------------------------------------------------
.../expressions/randomExpressions.scala | 50 +++++++-----
.../sql/catalyst/expressions/RandomSuite.scala | 6 ++
.../test/resources/sql-tests/inputs/random.sql | 17 ++++
.../resources/sql-tests/results/random.sql.out | 84 ++++++++++++++++++++
4 files changed, 135 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/340f09d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index a331a55..1d7a3c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
-abstract class RDG extends LeafExpression with Nondeterministic {
-
- protected def seed: Long
-
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
@@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic {
rng = new XORShiftRandom(seed + partitionIndex)
}
+ @transient protected lazy val seed: Long = child match {
+ case Literal(s, IntegerType) => s.asInstanceOf[Int]
+ case Literal(s, LongType) => s.asInstanceOf[Long]
+ case _ => throw new AnalysisException(
+ s"Input argument to $prettyName must be an integer, long or null literal.")
+ }
+
override def nullable: Boolean = false
override def dataType: DataType = DoubleType
- // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
- override def sql: String = s"$prettyName($seed)"
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
}
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
@@ -64,17 +66,15 @@ abstract class RDG extends LeafExpression with Nondeterministic {
0.9629742951434543
> SELECT _FUNC_(0);
0.8446490682263027
+ > SELECT _FUNC_(null);
+ 0.8446490682263027
""")
// scalastyle:on line.size.limit
-case class Rand(seed: Long) extends RDG {
- override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
+case class Rand(child: Expression) extends RDG {
- def this() = this(Utils.random.nextLong())
+ def this() = this(Literal(Utils.random.nextLong(), LongType))
- def this(seed: Expression) = this(seed match {
- case IntegerLiteral(s) => s
- case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
- })
+ override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rngTerm = ctx.freshName("rng")
@@ -87,6 +87,10 @@ case class Rand(seed: Long) extends RDG {
}
}
+object Rand {
+ def apply(seed: Long): Rand = Rand(Literal(seed, LongType))
+}
+
/** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -97,17 +101,15 @@ case class Rand(seed: Long) extends RDG {
-0.3254147983080288
> SELECT _FUNC_(0);
1.1164209726833079
+ > SELECT _FUNC_(null);
+ 1.1164209726833079
""")
// scalastyle:on line.size.limit
-case class Randn(seed: Long) extends RDG {
- override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
+case class Randn(child: Expression) extends RDG {
- def this() = this(Utils.random.nextLong())
+ def this() = this(Literal(Utils.random.nextLong(), LongType))
- def this(seed: Expression) = this(seed match {
- case IntegerLiteral(s) => s
- case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
- })
+ override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rngTerm = ctx.freshName("rng")
@@ -119,3 +121,7 @@ case class Randn(seed: Long) extends RDG {
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
}
}
+
+object Randn {
+ def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/340f09d1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index b7a0d44..752c9d5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -20,12 +20,18 @@ package org.apache.spark.sql.catalyst.expressions
import org.scalatest.Matchers._
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.{IntegerType, LongType}
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
test("random") {
checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)
+
+ checkDoubleEvaluation(
+ new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
+ checkDoubleEvaluation(
+ new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
}
test("SPARK-9127 codegen with long seed") {
http://git-wip-us.apache.org/repos/asf/spark/blob/340f09d1/sql/core/src/test/resources/sql-tests/inputs/random.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/random.sql b/sql/core/src/test/resources/sql-tests/inputs/random.sql
new file mode 100644
index 0000000..a1aae7b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/random.sql
@@ -0,0 +1,17 @@
+-- rand with the seed 0
+SELECT rand(0);
+SELECT rand(cast(3 / 7 AS int));
+SELECT rand(NULL);
+SELECT rand(cast(NULL AS int));
+
+-- rand unsupported data type
+SELECT rand(1.0);
+
+-- randn with the seed 0
+SELECT randn(0L);
+SELECT randn(cast(3 / 7 AS long));
+SELECT randn(NULL);
+SELECT randn(cast(NULL AS long));
+
+-- randn unsupported data type
+SELECT rand('1')
http://git-wip-us.apache.org/repos/asf/spark/blob/340f09d1/sql/core/src/test/resources/sql-tests/results/random.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out
new file mode 100644
index 0000000..bca6732
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out
@@ -0,0 +1,84 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 10
+
+
+-- !query 0
+SELECT rand(0)
+-- !query 0 schema
+struct<rand(0):double>
+-- !query 0 output
+0.8446490682263027
+
+
+-- !query 1
+SELECT rand(cast(3 / 7 AS int))
+-- !query 1 schema
+struct<rand(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS INT)):double>
+-- !query 1 output
+0.8446490682263027
+
+
+-- !query 2
+SELECT rand(NULL)
+-- !query 2 schema
+struct<rand(CAST(NULL AS INT)):double>
+-- !query 2 output
+0.8446490682263027
+
+
+-- !query 3
+SELECT rand(cast(NULL AS int))
+-- !query 3 schema
+struct<rand(CAST(NULL AS INT)):double>
+-- !query 3 output
+0.8446490682263027
+
+
+-- !query 4
+SELECT rand(1.0)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'rand(1.0BD)' due to data type mismatch: argument 1 requires (int or bigint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7
+
+
+-- !query 5
+SELECT randn(0L)
+-- !query 5 schema
+struct<randn(0):double>
+-- !query 5 output
+1.1164209726833079
+
+
+-- !query 6
+SELECT randn(cast(3 / 7 AS long))
+-- !query 6 schema
+struct<randn(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS BIGINT)):double>
+-- !query 6 output
+1.1164209726833079
+
+
+-- !query 7
+SELECT randn(NULL)
+-- !query 7 schema
+struct<randn(CAST(NULL AS INT)):double>
+-- !query 7 output
+1.1164209726833079
+
+
+-- !query 8
+SELECT randn(cast(NULL AS long))
+-- !query 8 schema
+struct<randn(CAST(NULL AS BIGINT)):double>
+-- !query 8 output
+1.1164209726833079
+
+
+-- !query 9
+SELECT rand('1')
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'rand('1')' due to data type mismatch: argument 1 requires (int or bigint) type, however, ''1'' is of string type.; line 1 pos 7