You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/19 06:43:21 UTC
spark git commit: [SPARK-7150] SparkContext.range() and
SQLContext.range()
Repository: spark
Updated Branches:
refs/heads/master d03638cc2 -> c2437de18
[SPARK-7150] SparkContext.range() and SQLContext.range()
This PR is based on #6081, thanks adrian-wang.
Closes #6081
Author: Daoyuan Wang <da...@intel.com>
Author: Davies Liu <da...@databricks.com>
Closes #6230 from davies/range and squashes the following commits:
d3ce5fe [Davies Liu] add tests
789eda5 [Davies Liu] add range() in Python
4590208 [Davies Liu] Merge commit 'refs/pull/6081/head' of github.com:apache/spark into range
cbf5200 [Daoyuan Wang] let's add python support in a separate PR
f45e3b2 [Daoyuan Wang] remove redundant toLong
617da76 [Daoyuan Wang] fix safe marge for corner cases
867c417 [Daoyuan Wang] fix
13dbe84 [Daoyuan Wang] update
bd998ba [Daoyuan Wang] update comments
d3a0c1b [Daoyuan Wang] add range api()
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c2437de1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c2437de1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c2437de1
Branch: refs/heads/master
Commit: c2437de1899e09894df4ec27adfaa7fac158fd3a
Parents: d03638c
Author: Daoyuan Wang <da...@intel.com>
Authored: Mon May 18 21:43:12 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon May 18 21:43:12 2015 -0700
----------------------------------------------------------------------
.../scala/org/apache/spark/SparkContext.scala | 72 ++++++++++++++++++++
python/pyspark/context.py | 16 +++++
python/pyspark/sql/context.py | 20 ++++++
python/pyspark/sql/tests.py | 5 ++
python/pyspark/tests.py | 5 ++
.../scala/org/apache/spark/sql/SQLContext.scala | 31 +++++++++
.../org/apache/spark/sql/DataFrameSuite.scala | 40 +++++++++++
7 files changed, 189 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c2437de1/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f78fbaf..3fe3dc5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -697,6 +697,78 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
+ /**
+ * Creates a new RDD[Long] containing elements from `start` (inclusive) to `end` (exclusive),
+ * increased by `step` every element.
+ *
+ * The number of elements is computed with BigInt, and partition bounds that do not fit in a
+ * Long are clamped to Long.MaxValue / Long.MinValue.
+ *
+ * @note if we need to cache this RDD, we should make sure each partition does not exceed limit.
+ *
+ * @param start the start value (inclusive).
+ * @param end the end value (exclusive).
+ * @param step the incremental step; must not be 0, may be negative.
+ * @param numSlices the partition number of the new RDD.
+ * @return an RDD[Long] holding the elements of the range.
+ */
+ def range(
+ start: Long,
+ end: Long,
+ step: Long = 1,
+ numSlices: Int = defaultParallelism): RDD[Long] = withScope {
+ assertNotStopped()
+ // with a step of 0 the range would never reach `end`, so reject it up front
+ require(step != 0, "step cannot be 0")
+ val numElements: BigInt = {
+ val safeStart = BigInt(start)
+ val safeEnd = BigInt(end)
+ if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) {
+ (safeEnd - safeStart) / step
+ } else {
+ // not evenly divisible and `step` points from start towards end: one more element fits
+ (safeEnd - safeStart) / step + 1
+ }
+ }
+ parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => {
+ val partitionStart = (i * numElements) / numSlices * step + start
+ val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
+ def getSafeMargin(bi: BigInt): Long = // clamps a BigInt bound into the Long range
+ if (bi.isValidLong) {
+ bi.toLong
+ } else if (bi > 0) {
+ Long.MaxValue
+ } else {
+ Long.MinValue
+ }
+ val safePartitionStart = getSafeMargin(partitionStart)
+ val safePartitionEnd = getSafeMargin(partitionEnd)
+
+ new Iterator[Long] {
+ private[this] var number: Long = safePartitionStart
+ private[this] var overflow: Boolean = false // set once `number += step` has wrapped around
+
+ override def hasNext =
+ if (!overflow) {
+ if (step > 0) {
+ number < safePartitionEnd
+ } else {
+ number > safePartitionEnd
+ }
+ } else false
+
+ override def next() = {
+ val ret = number
+ number += step
+ if (number < ret ^ step < 0) {
+ // Long arithmetic wraps around: Long.MaxValue + Long.MaxValue < Long.MaxValue and
+ // Long.MinValue + Long.MinValue > Long.MinValue. So if adding `step` moved `number`
+ // against the direction of `step`, the addition has overflowed.
+ overflow = true
+ }
+ ret
+ }
+ }
+ })
+ }
+
/** Distribute a local Scala collection to form an RDD.
*
* This method is identical to `parallelize`.
http://git-wip-us.apache.org/repos/asf/spark/blob/c2437de1/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d25ee85..1f2b40b 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -319,6 +319,22 @@ class SparkContext(object):
with SparkContext._lock:
SparkContext._active_spark_context = None
+ def range(self, start, end, step=1, numSlices=None):
+ """
+ Create a new RDD of int containing elements from `start` to `end`
+ (exclusive), increased by `step` every element.
+
+ :param start: the start value
+ :param end: the end value (exclusive)
+ :param step: the incremental step (default: 1)
+ :param numSlices: the number of partitions of the new RDD
+ :return: An RDD of int
+
+ >>> sc.range(1, 7, 2).collect()
+ [1, 3, 5]
+ """
+ return self.parallelize(xrange(start, end, step), numSlices)
+
def parallelize(self, c, numSlices=None):
"""
Distribute a local Python collection to form an RDD. Using xrange
http://git-wip-us.apache.org/repos/asf/spark/blob/c2437de1/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 0bde719..9f26d13 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -122,6 +122,26 @@ class SQLContext(object):
"""Returns a :class:`UDFRegistration` for UDF registration."""
return UDFRegistration(self)
+ def range(self, start, end, step=1, numPartitions=None):
+ """
+ Create a :class:`DataFrame` with single LongType column named `id`,
+ containing elements in a range from `start` to `end` (exclusive) with
+ step value `step`.
+
+ :param start: the start value
+ :param end: the end value (exclusive)
+ :param step: the incremental step (default: 1)
+ :param numPartitions: the number of partitions of the DataFrame
+ :return: A new DataFrame
+
+ >>> sqlContext.range(1, 7, 2).collect()
+ [Row(id=1), Row(id=3), Row(id=5)]
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+ return DataFrame(jdf, self)
+
@ignore_unicode_prefix
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
http://git-wip-us.apache.org/repos/asf/spark/blob/c2437de1/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d37c5db..84ae36f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -117,6 +117,11 @@ class SQLTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+ def test_range(self):
+ self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
+ self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
+ self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
+
def test_explode(self):
from pyspark.sql.functions import explode
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
http://git-wip-us.apache.org/repos/asf/spark/blob/c2437de1/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 5e023f6..d8e3199 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -444,6 +444,11 @@ class AddFileTests(PySparkTestCase):
class RDDTests(ReusedPySparkTestCase):
+ def test_range(self):
+ self.assertEqual(self.sc.range(1, 1).count(), 0)
+ self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
+ self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
+
def test_id(self):
rdd = self.sc.parallelize(range(10))
id = rdd.id()
http://git-wip-us.apache.org/repos/asf/spark/blob/c2437de1/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ac1a800..316ef7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -685,6 +685,37 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
/**
+ * :: Experimental ::
+ * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+ * in an range from `start` to `end`(exclusive) with step value 1.
+ *
+ * @since 1.4.0
+ * @group dataframe
+ */
+ @Experimental
+ def range(start: Long, end: Long): DataFrame = {
+ createDataFrame(
+ sparkContext.range(start, end).map(Row(_)),
+ StructType(StructField("id", LongType, nullable = false) :: Nil))
+ }
+
+ /**
+  * :: Experimental ::
+  * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, holding the values
+  * of the range from `start` to `end` (exclusive), advancing by `step`, with the partition
+  * number given by `numPartitions`.
+  *
+  * @since 1.4.0
+  * @group dataframe
+  */
+ @Experimental
+ def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
+   val idRows = sparkContext.range(start, end, step, numPartitions).map(id => Row(id))
+   val idField = StructField("id", LongType, nullable = false)
+   createDataFrame(idRows, StructType(idField :: Nil))
+ }
+
+ /**
* Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
* used for SQL parsing can be configured with 'spark.sql.dialect'.
*
http://git-wip-us.apache.org/repos/asf/spark/blob/c2437de1/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 054b23d..f05d059 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -532,4 +532,44 @@ class DataFrameSuite extends QueryTest {
val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
assert(!p.child.isInstanceOf[Project])
}
+
+ test("SPARK-7150 range api") {
+   // Each case: (range under test, expected row count, expected sum of `id` when checked).
+   val cases = Seq(
+     // numSlices is greater than the number of elements
+     (TestSQLContext.range(0, 10, 1, 15), 10L, Some(45L)),
+     (TestSQLContext.range(3, 15, 3, 2), 4L, Some(30L)),
+     // end is below start while the step is the default 1: nothing to produce
+     (TestSQLContext.range(1, -2), 0L, None),
+     // start is positive, end is negative, step is negative
+     (TestSQLContext.range(1, -2, -2, 6), 2L, Some(0L)),
+     // start, end, step are negative
+     (TestSQLContext.range(-3, -8, -2, 1), 3L, Some(-15L)),
+     // start, end are negative, step is positive
+     (TestSQLContext.range(-8, -4, 2, 1), 2L, Some(-14L)),
+     // step points away from end: nothing to produce
+     (TestSQLContext.range(-10, -9, -20, 1), 0L, None),
+     // bounds and steps at the extremes of Long
+     (TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100), 3L, Some(-3L)),
+     (TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100), 2L,
+       Some(Long.MaxValue - 1)))
+
+   cases.foreach { case (rangeDf, expectedCount, expectedSum) =>
+     val ids = rangeDf.select("id")
+     assert(ids.count == expectedCount)
+     expectedSum.foreach { total =>
+       assert(ids.agg(sum("id")).as("sumid").collect() === Seq(Row(total)))
+     }
+   }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org