You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2018/01/19 15:28:38 UTC

spark git commit: [SPARK-23085][ML] API parity for mllib.linalg.Vectors.sparse

Repository: spark
Updated Branches:
  refs/heads/master 6c39654ef -> 606a7485f


[SPARK-23085][ML] API parity for mllib.linalg.Vectors.sparse

## What changes were proposed in this pull request?
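`ML.Vectors#sparse(size: Int, elements: Seq[(Int, Double)])` already supports zero-length vectors; this change makes `mllib.linalg.Vectors.sparse` accept them as well, and moves the non-negative size check into the `mllib` `SparseVector` constructor.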

## How was this patch tested?
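Existing tests, plus a new unit test added to each `VectorsSuite` (zero-size construction succeeds; a negative size throws `IllegalArgumentException`).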

Author: Zheng RuiFeng <ru...@foxmail.com>

Closes #20275 from zhengruifeng/SparseVector_size.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/606a7485
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/606a7485
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/606a7485

Branch: refs/heads/master
Commit: 606a7485f12c5d5377c50258006c353ba5e49c3f
Parents: 6c39654
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Fri Jan 19 09:28:35 2018 -0600
Committer: Sean Owen <so...@cloudera.com>
Committed: Fri Jan 19 09:28:35 2018 -0600

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/linalg/Vectors.scala    |  2 +-
 .../org/apache/spark/ml/linalg/VectorsSuite.scala     | 14 ++++++++++++++
 .../scala/org/apache/spark/mllib/linalg/Vectors.scala |  3 +--
 .../org/apache/spark/mllib/linalg/VectorsSuite.scala  | 14 ++++++++++++++
 4 files changed, 30 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/606a7485/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 941b6ec..5824e46 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -565,7 +565,7 @@ class SparseVector @Since("2.0.0") (
 
   // validate the data
   {
-    require(size >= 0, "The size of the requested sparse vector must be greater than 0.")
+    require(size >= 0, "The size of the requested sparse vector must be no less than 0.")
     require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
       s" indices match the dimension of the values. You provided ${indices.length} indices and " +
       s" ${values.length} values.")

http://git-wip-us.apache.org/repos/asf/spark/blob/606a7485/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
index 79acef8..0a316f5 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
@@ -366,4 +366,18 @@ class VectorsSuite extends SparkMLFunSuite {
     assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
     assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
   }
+
+  test("sparse vector only support non-negative length") {
+    val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray)
+    val v2 = Vectors.sparse(0, Array.empty[(Int, Double)])
+    assert(v1.size === 0)
+    assert(v2.size === 0)
+
+    intercept[IllegalArgumentException] {
+      Vectors.sparse(-1, Array(1), Array(2.0))
+    }
+    intercept[IllegalArgumentException] {
+      Vectors.sparse(-1, Array((1, 2.0)))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/606a7485/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index fd9605c..6e68d96 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -326,8 +326,6 @@ object Vectors {
    */
   @Since("1.0.0")
   def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {
-    require(size > 0, "The size of the requested sparse vector must be greater than 0.")
-
     val (indices, values) = elements.sortBy(_._1).unzip
     var prev = -1
     indices.foreach { i =>
@@ -758,6 +756,7 @@ class SparseVector @Since("1.0.0") (
     @Since("1.0.0") val indices: Array[Int],
     @Since("1.0.0") val values: Array[Double]) extends Vector {
 
+  require(size >= 0, "The size of the requested sparse vector must be no less than 0.")
   require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
     s" indices match the dimension of the values. You provided ${indices.length} indices and " +
     s" ${values.length} values.")

http://git-wip-us.apache.org/repos/asf/spark/blob/606a7485/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 4074bea..217b4a3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -495,4 +495,18 @@ class VectorsSuite extends SparkFunSuite with Logging {
     assert(mlDenseVectorToArray(dv) === mlDenseVectorToArray(newDV))
     assert(mlSparseVectorToArray(sv) === mlSparseVectorToArray(newSV))
   }
+
+  test("sparse vector only support non-negative length") {
+    val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray)
+    val v2 = Vectors.sparse(0, Array.empty[(Int, Double)])
+    assert(v1.size === 0)
+    assert(v2.size === 0)
+
+    intercept[IllegalArgumentException] {
+      Vectors.sparse(-1, Array(1), Array(2.0))
+    }
+    intercept[IllegalArgumentException] {
+      Vectors.sparse(-1, Array((1, 2.0)))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org