You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/01 06:56:08 UTC

spark git commit: [SPARK-7248] implemented random number generators for DataFrames

Repository: spark
Updated Branches:
  refs/heads/master 69a739c7f -> b5347a466


 [SPARK-7248] implemented random number generators for DataFrames

Adds the functions `rand` (uniform distribution) and `randn` (standard normal distribution) as DataFrame expressions.

cc mengxr rxin

Author: Burak Yavuz <br...@gmail.com>

Closes #5819 from brkyvz/df-rng and squashes the following commits:

50d69d4 [Burak Yavuz] add seed for test that failed
4234c3a [Burak Yavuz] fix Rand expression
13cad5c [Burak Yavuz] couple fixes
7d53953 [Burak Yavuz] waiting for hive tests
b453716 [Burak Yavuz] move radn with seed down
03637f0 [Burak Yavuz] fix broken hive func
c5909eb [Burak Yavuz] deleted old implementation of Rand
6d43895 [Burak Yavuz] implemented random generators


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b5347a46
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b5347a46
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b5347a46

Branch: refs/heads/master
Commit: b5347a4664625ede6ab9d8ef6558457a34ae423f
Parents: 69a739c
Author: Burak Yavuz <br...@gmail.com>
Authored: Thu Apr 30 21:56:03 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Apr 30 21:56:03 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 25 ++++++++-
 python/pyspark/sql/tests.py                     | 10 ++++
 .../spark/sql/catalyst/expressions/Rand.scala   | 36 -------------
 .../spark/sql/catalyst/expressions/random.scala | 56 ++++++++++++++++++++
 .../optimizer/ConstantFoldingSuite.scala        |  4 +-
 .../scala/org/apache/spark/sql/functions.scala  | 30 ++++++++++-
 .../org/apache/spark/sql/mathfunctions.scala    |  2 -
 .../apache/spark/sql/JavaDataFrameSuite.java    |  3 ++
 .../spark/sql/ColumnExpressionSuite.scala       | 22 ++++++++
 .../org/apache/spark/sql/hive/HiveQl.scala      |  7 ++-
 10 files changed, 149 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 555c2fa..241f821 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -67,7 +67,6 @@ _functions = {
     'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
 }
 
-
 for _name, _doc in _functions.items():
     globals()[_name] = _create_function(_name, _doc)
 del _name, _doc
@@ -75,6 +74,30 @@ __all__ += _functions.keys()
 __all__.sort()
 
 
+def rand(seed=None):
+    """
+    Generate a random column with i.i.d. samples from U[0.0, 1.0].
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.rand(seed)
+    else:
+        jc = sc._jvm.functions.rand()
+    return Column(jc)
+
+
+def randn(seed=None):
+    """
+    Generate a column with i.i.d. samples from the standard normal distribution.
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.randn(seed)
+    else:
+        jc = sc._jvm.functions.randn()
+    return Column(jc)
+
+
 def approxCountDistinct(col, rsd=None):
     """Returns a new :class:`Column` for approximate distinct count of ``col``.
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2ffd18e..5640bb5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -416,6 +416,16 @@ class SQLTests(ReusedPySparkTestCase):
         assert_close([math.hypot(i, 2 * i) for i in range(10)],
                      df.select(functions.hypot(df.a, df.b)).collect())
 
+    def test_rand_functions(self):
+        df = self.df
+        from pyspark.sql import functions
+        rnd = df.select('key', functions.rand()).collect()
+        for row in rnd:
+            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
+        rndn = df.select('key', functions.randn(5)).collect()
+        for row in rndn:
+            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
+
     def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
deleted file mode 100644
index f5fea3f..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import java.util.Random
-
-import org.apache.spark.sql.types.{DataType, DoubleType}
-
-
-case object Rand extends LeafExpression {
-  override def dataType: DataType = DoubleType
-  override def nullable: Boolean = false
-
-  private[this] lazy val rand = new Random
-
-  override def eval(input: Row = null): EvaluatedType = {
-    rand.nextDouble().asInstanceOf[EvaluatedType]
-  }
-
-  override def toString: String = "RAND()"
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
new file mode 100644
index 0000000..66d7c8b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * A Random distribution generating expression.
+ * TODO: This can be made generic to generate any type of random distribution, or any type of
+ * StructType.
+ *
+ * Since this expression is stateful, it cannot be a case object.
+ */
+abstract class RDG(seed: Long) extends LeafExpression with Serializable {
+  self: Product =>
+
+  /**
+   * Random number generator seeded with `seed` plus the partition ID (0 outside of a task, e.g.
+   * on the driver). Being transient, it is re-created after every (de)serialization.
+   */
+  @transient protected lazy val rng =
+    new XORShiftRandom(seed + Option(TaskContext.get()).map(_.partitionId()).getOrElse(0))
+  override type EvaluatedType = Double
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = DoubleType
+}
+
+/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
+case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+  override def eval(input: Row): Double = rng.nextDouble()
+}
+
+/** Generate a random column with i.i.d. gaussian random distribution. */
+case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+  override def eval(input: Row): Double = rng.nextGaussian()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 14b28e8..18f9215 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -160,7 +160,7 @@ class ConstantFoldingSuite extends PlanTest {
     val originalQuery =
       testRelation
         .select(
-          Rand + Literal(1) as Symbol("c1"),
+          Rand(5L) + Literal(1) as Symbol("c1"),
           Sum('a) as Symbol("c2"))
 
     val optimized = Optimize.execute(originalQuery.analyze)
@@ -168,7 +168,7 @@ class ConstantFoldingSuite extends PlanTest {
     val correctAnswer =
       testRelation
         .select(
-          Rand + Literal(1.0) as Symbol("c1"),
+          Rand(5L) + Literal(1.0) as Symbol("c1"),
           Sum('a) as Symbol("c2"))
         .analyze
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index aa31d04..242e64d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
-
+import org.apache.spark.util.Utils
 
 /**
  * :: Experimental ::
@@ -347,6 +347,34 @@ object functions {
   def not(e: Column): Column = !e
 
   /**
+   * Generate a random column with i.i.d. samples from U[0.0, 1.0].
+   *
+   * @group normal_funcs
+   */
+  def rand(seed: Long): Column = Rand(seed)
+
+  /**
+   * Generate a random column with i.i.d. samples from U[0.0, 1.0].
+   *
+   * @group normal_funcs
+   */
+  def rand(): Column = rand(Utils.random.nextLong)
+
+  /**
+   * Generate a column with i.i.d. samples from the standard normal distribution.
+   *
+   * @group normal_funcs
+   */
+  def randn(seed: Long): Column = Randn(seed)
+
+  /**
+   * Generate a column with i.i.d. samples from the standard normal distribution.
+   *
+   * @group normal_funcs
+   */
+  def randn(): Column = randn(Utils.random.nextLong)
+
+  /**
    * Partition ID of the Spark task.
    *
    * Note that this is indeterministic because it depends on data partitioning and task scheduling.

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
index d901542..db47480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.functions.lit
 /**
  * :: Experimental ::
  * Mathematical Functions available for [[DataFrame]].
- *
- * @groupname double_funcs Functions that require DoubleType as an input
  */
 @Experimental
 // scalastyle:off

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 966d879..ebe96e6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -104,6 +104,9 @@ public class JavaDataFrameSuite {
     df2.select(pow("a", "a"), pow("b", 2.0));
     df2.select(pow(col("a"), col("b")), exp("b"));
     df2.select(sin("a"), acos("b"));
+
+    df2.select(rand(), acos("b"));
+    df2.select(col("*"), randn(5L));
   }
 
   @Ignore

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 2ba5fc2..6322faf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import org.scalatest.Matchers._
+
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -349,4 +351,24 @@ class ColumnExpressionSuite extends QueryTest {
     assert(schema("value").metadata === Metadata.empty)
     assert(schema("abc").metadata === metadata)
   }
+
+  test("rand") {
+    val randCol = testData.select('key, rand(5L).as("rand"))
+    randCol.columns.length should be (2)
+    val rows = randCol.collect()
+    rows.foreach { row =>
+      assert(row.getDouble(1) <= 1.0)
+      assert(row.getDouble(1) >= 0.0)
+    }
+  }
+
+  test("randn") {
+    val randCol = testData.select('key, randn(5L).as("rand"))
+    randCol.columns.length should be (2)
+    val rows = randCol.collect()
+    rows.foreach { row =>
+      assert(row.getDouble(1) <= 4.0)
+      assert(row.getDouble(1) >= -4.0)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b5347a46/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 0a86519..63a8c05 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -19,13 +19,11 @@ package org.apache.spark.sql.hive
 
 import java.sql.Date
 
-
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
-
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
 import org.apache.hadoop.hive.ql.lib.Node
 import org.apache.hadoop.hive.ql.metadata.Table
 import org.apache.hadoop.hive.ql.parse._
@@ -1244,7 +1242,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     /* Other functions */
     case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
       CreateArray(children.map(nodeToExpr))
-    case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+    case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand()
+    case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong)
     case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
       Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType))
     case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org