You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/01 06:56:08 UTC
spark git commit: [SPARK-7248] implemented random number generators
for DataFrames
Repository: spark
Updated Branches:
refs/heads/master 69a739c7f -> b5347a466
[SPARK-7248] implemented random number generators for DataFrames
Adds the functions `rand` (uniform distribution) and `randn` (standard normal distribution) as expressions to DataFrames.
cc mengxr rxin
Author: Burak Yavuz <br...@gmail.com>
Closes #5819 from brkyvz/df-rng and squashes the following commits:
50d69d4 [Burak Yavuz] add seed for test that failed
4234c3a [Burak Yavuz] fix Rand expression
13cad5c [Burak Yavuz] couple fixes
7d53953 [Burak Yavuz] waiting for hive tests
b453716 [Burak Yavuz] move randn with seed down
03637f0 [Burak Yavuz] fix broken hive func
c5909eb [Burak Yavuz] deleted old implementation of Rand
6d43895 [Burak Yavuz] implemented random generators
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b5347a46
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b5347a46
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b5347a46
Branch: refs/heads/master
Commit: b5347a4664625ede6ab9d8ef6558457a34ae423f
Parents: 69a739c
Author: Burak Yavuz <br...@gmail.com>
Authored: Thu Apr 30 21:56:03 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Apr 30 21:56:03 2015 -0700
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 25 ++++++++-
python/pyspark/sql/tests.py | 10 ++++
.../spark/sql/catalyst/expressions/Rand.scala | 36 -------------
.../spark/sql/catalyst/expressions/random.scala | 56 ++++++++++++++++++++
.../optimizer/ConstantFoldingSuite.scala | 4 +-
.../scala/org/apache/spark/sql/functions.scala | 30 ++++++++++-
.../org/apache/spark/sql/mathfunctions.scala | 2 -
.../apache/spark/sql/JavaDataFrameSuite.java | 3 ++
.../spark/sql/ColumnExpressionSuite.scala | 22 ++++++++
.../org/apache/spark/sql/hive/HiveQl.scala | 7 ++-
10 files changed, 149 insertions(+), 46 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 555c2fa..241f821 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -67,7 +67,6 @@ _functions = {
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}
-
for _name, _doc in _functions.items():
globals()[_name] = _create_function(_name, _doc)
del _name, _doc
@@ -75,6 +74,30 @@ __all__ += _functions.keys()
__all__.sort()
+def rand(seed=None):
+    """
+    Generate a random column with i.i.d. samples uniformly distributed in [0.0, 1.0).
+
+    :param seed: optional integer seed. When ``None`` (the default) the
+        no-argument JVM ``rand()`` is called; any other value, including
+        ``0``, is forwarded as the seed.
+    :return: a :class:`Column` wrapping the JVM expression.
+    """
+    sc = SparkContext._active_spark_context
+    # Compare against None explicitly: a truthiness test would treat seed=0
+    # as "no seed" and silently return an unseeded column.
+    if seed is not None:
+        jc = sc._jvm.functions.rand(seed)
+    else:
+        jc = sc._jvm.functions.rand()
+    return Column(jc)
+
+
+def randn(seed=None):
+    """
+    Generate a column with i.i.d. samples from the standard normal distribution.
+
+    :param seed: optional integer seed. When ``None`` (the default) the
+        no-argument JVM ``randn()`` is called; any other value, including
+        ``0``, is forwarded as the seed.
+    :return: a :class:`Column` wrapping the JVM expression.
+    """
+    sc = SparkContext._active_spark_context
+    # Compare against None explicitly: a truthiness test would treat seed=0
+    # as "no seed" and silently return an unseeded column.
+    if seed is not None:
+        jc = sc._jvm.functions.randn(seed)
+    else:
+        jc = sc._jvm.functions.randn()
+    return Column(jc)
+
+
def approxCountDistinct(col, rsd=None):
"""Returns a new :class:`Column` for approximate distinct count of ``col``.
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2ffd18e..5640bb5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -416,6 +416,16 @@ class SQLTests(ReusedPySparkTestCase):
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())
+    def test_rand_functions(self):
+        # Smoke test: values produced by rand()/randn() must fall in the expected ranges.
+        frame = self.df
+        from pyspark.sql import functions
+        # Uniform samples must lie inside the unit interval.
+        for sample in frame.select('key', functions.rand()).collect():
+            assert 0.0 <= sample[1] <= 1.0, "got: %s" % sample[1]
+        # Seeded normal samples are checked against a loose +/-4.0 bound.
+        for sample in frame.select('key', functions.randn(5)).collect():
+            assert -4.0 <= sample[1] <= 4.0, "got: %s" % sample[1]
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
deleted file mode 100644
index f5fea3f..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import java.util.Random
-
-import org.apache.spark.sql.types.{DataType, DoubleType}
-
-
-case object Rand extends LeafExpression {
- override def dataType: DataType = DoubleType
- override def nullable: Boolean = false
-
- private[this] lazy val rand = new Random
-
- override def eval(input: Row = null): EvaluatedType = {
- rand.nextDouble().asInstanceOf[EvaluatedType]
- }
-
- override def toString: String = "RAND()"
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
new file mode 100644
index 0000000..66d7c8b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * A Random distribution generating expression: a leaf expression that evaluates to a
+ * non-nullable Double.
+ * TODO: This can be made generic to generate any type of random distribution, or any type of
+ * StructType.
+ *
+ * Since this expression is stateful, it cannot be a case object.
+ *
+ * @param seed base seed; the partition ID of the running task is added to it when the
+ * generator is created.
+ */
+abstract class RDG(seed: Long) extends LeafExpression with Serializable {
+ self: Product =>
+
+ /**
+ * Random number generator, created lazily and seeded with `seed` plus the partition ID of
+ * the current task. By being transient, the Random Number Generator is
+ * reset every time we serialize and deserialize it.
+ */
+ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId())
+
+ override type EvaluatedType = Double
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = DoubleType
+}
+
+/**
+ * Generate a random column with i.i.d. uniformly distributed values in [0, 1).
+ *
+ * @param seed base seed; defaults to a value drawn from `Utils.random`.
+ */
+case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+ // The input row is ignored; each call draws the next value from the generator.
+ override def eval(input: Row): Double = rng.nextDouble()
+}
+
+/**
+ * Generate a random column with i.i.d. gaussian random distribution.
+ *
+ * @param seed base seed; defaults to a value drawn from `Utils.random`.
+ */
+case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+ // The input row is ignored; each call draws the next value from the generator.
+ override def eval(input: Row): Double = rng.nextGaussian()
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 14b28e8..18f9215 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -160,7 +160,7 @@ class ConstantFoldingSuite extends PlanTest {
val originalQuery =
testRelation
.select(
- Rand + Literal(1) as Symbol("c1"),
+ Rand(5L) + Literal(1) as Symbol("c1"),
Sum('a) as Symbol("c2"))
val optimized = Optimize.execute(originalQuery.analyze)
@@ -168,7 +168,7 @@ class ConstantFoldingSuite extends PlanTest {
val correctAnswer =
testRelation
.select(
- Rand + Literal(1.0) as Symbol("c1"),
+ Rand(5L) + Literal(1.0) as Symbol("c1"),
Sum('a) as Symbol("c2"))
.analyze
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index aa31d04..242e64d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-
+import org.apache.spark.util.Utils
/**
* :: Experimental ::
@@ -347,6 +347,34 @@ object functions {
def not(e: Column): Column = !e
/**
+ * Generate a random column with i.i.d. samples uniformly distributed in [0.0, 1.0),
+ * using the given seed.
+ *
+ * @group normal_funcs
+ */
+ def rand(seed: Long): Column = Rand(seed)
+
+ /**
+ * Generate a random column with i.i.d. samples uniformly distributed in [0.0, 1.0),
+ * using a seed drawn from `Utils.random`.
+ *
+ * @group normal_funcs
+ */
+ def rand(): Column = rand(Utils.random.nextLong)
+
+ /**
+ * Generate a column with i.i.d. samples from the standard normal distribution,
+ * using the given seed.
+ *
+ * @group normal_funcs
+ */
+ def randn(seed: Long): Column = Randn(seed)
+
+ /**
+ * Generate a column with i.i.d. samples from the standard normal distribution,
+ * using a seed drawn from `Utils.random`.
+ *
+ * @group normal_funcs
+ */
+ def randn(): Column = randn(Utils.random.nextLong)
+
+ /**
* Partition ID of the Spark task.
*
* Note that this is indeterministic because it depends on data partitioning and task scheduling.
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
index d901542..db47480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.functions.lit
/**
* :: Experimental ::
* Mathematical Functions available for [[DataFrame]].
- *
- * @groupname double_funcs Functions that require DoubleType as an input
*/
@Experimental
// scalastyle:off
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 966d879..ebe96e6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -104,6 +104,9 @@ public class JavaDataFrameSuite {
df2.select(pow("a", "a"), pow("b", 2.0));
df2.select(pow(col("a"), col("b")), exp("b"));
df2.select(sin("a"), acos("b"));
+
+ df2.select(rand(), acos("b"));
+ df2.select(col("*"), randn(5L));
}
@Ignore
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 2ba5fc2..6322faf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import org.scalatest.Matchers._
+
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -349,4 +351,24 @@ class ColumnExpressionSuite extends QueryTest {
assert(schema("value").metadata === Metadata.empty)
assert(schema("abc").metadata === metadata)
}
+
+ // Selecting a seeded rand column next to 'key yields two columns, and every generated
+ // value must lie within [0.0, 1.0].
+ test("rand") {
+ val randCol = testData.select('key, rand(5L).as("rand"))
+ randCol.columns.length should be (2)
+ val rows = randCol.collect()
+ rows.foreach { row =>
+ assert(row.getDouble(1) <= 1.0)
+ assert(row.getDouble(1) >= 0.0)
+ }
+ }
+
+ // Same shape check for randn; values are checked against the loose bound [-4.0, 4.0].
+ // NOTE(review): the column is aliased "rand" here as well; presumably copied from the
+ // test above, confirm whether "randn" was intended.
+ test("randn") {
+ val randCol = testData.select('key, randn(5L).as("rand"))
+ randCol.columns.length should be (2)
+ val rows = randCol.collect()
+ rows.foreach { row =>
+ assert(row.getDouble(1) <= 4.0)
+ assert(row.getDouble(1) >= -4.0)
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 0a86519..63a8c05 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -19,13 +19,11 @@ package org.apache.spark.sql.hive
import java.sql.Date
-
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
-
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.parse._
@@ -1244,7 +1242,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
/* Other functions */
case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
CreateArray(children.map(nodeToExpr))
- case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand()
+ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong)
case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType))
case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org