You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/12/15 22:44:19 UTC

spark git commit: [SPARK-4494][mllib] IDFModel.transform() add support for single vector

Repository: spark
Updated Branches:
  refs/heads/master 4c0673879 -> 8098fab06


[SPARK-4494][mllib] IDFModel.transform() add support for single vector

I extended `IDFModel.transform` so that, in addition to an RDD of term frequency vectors, it also accepts a single local term frequency vector.

[[SPARK-4494] IDFModel.transform() add support for single vector - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-4494)

Author: Yuu ISHIKAWA <yu...@gmail.com>

Closes #3603 from yu-iskw/idf and squashes the following commits:

256ff3d [Yuu ISHIKAWA] Fix typo
a3bf566 [Yuu ISHIKAWA] - Fix typo - Optimize import order - Aggregate the assertion tests - Modify `IDFModel.transform` API for pyspark
d25e49b [Yuu ISHIKAWA] Add the implementation of `IDFModel.transform` for a term frequency vector


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8098fab0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8098fab0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8098fab0

Branch: refs/heads/master
Commit: 8098fab06cb2be22cca4e531e8e65ab29dbb909a
Parents: 4c06738
Author: Yuu ISHIKAWA <yu...@gmail.com>
Authored: Mon Dec 15 13:44:15 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Dec 15 13:44:15 2014 -0800

----------------------------------------------------------------------
 .../org/apache/spark/mllib/feature/IDF.scala    | 73 ++++++++++++--------
 .../apache/spark/mllib/feature/IDFSuite.scala   | 67 +++++++++++-------
 python/pyspark/mllib/feature.py                 | 22 ++++--
 3 files changed, 101 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8098fab0/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index 720bb70..19120e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -174,37 +174,18 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable {
    */
   def transform(dataset: RDD[Vector]): RDD[Vector] = {
     val bcIdf = dataset.context.broadcast(idf)
-    dataset.mapPartitions { iter =>
-      val thisIdf = bcIdf.value
-      iter.map { v =>
-        val n = v.size
-        v match {
-          case sv: SparseVector =>
-            val nnz = sv.indices.size
-            val newValues = new Array[Double](nnz)
-            var k = 0
-            while (k < nnz) {
-              newValues(k) = sv.values(k) * thisIdf(sv.indices(k))
-              k += 1
-            }
-            Vectors.sparse(n, sv.indices, newValues)
-          case dv: DenseVector =>
-            val newValues = new Array[Double](n)
-            var j = 0
-            while (j < n) {
-              newValues(j) = dv.values(j) * thisIdf(j)
-              j += 1
-            }
-            Vectors.dense(newValues)
-          case other =>
-            throw new UnsupportedOperationException(
-              s"Only sparse and dense vectors are supported but got ${other.getClass}.")
-        }
-      }
-    }
+    dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v)))
   }
 
   /**
+   * Transforms a term frequency (TF) vector to a TF-IDF vector
+   *
+   * @param v a term frequency vector
+   * @return a TF-IDF vector
+   */
+  def transform(v: Vector): Vector = IDFModel.transform(idf, v)
+
+  /**
    * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version).
    * @param dataset a JavaRDD of term frequency vectors
    * @return a JavaRDD of TF-IDF vectors
@@ -213,3 +194,39 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable {
     transform(dataset.rdd).toJavaRDD()
   }
 }
+
+private object IDFModel {
+
+  /**
+   * Transforms a term frequency (TF) vector to a TF-IDF vector with an IDF vector
+   *
+   * @param idf an IDF vector
+   * @param v a term frequency vector
+   * @return a TF-IDF vector
+   */
+  def transform(idf: Vector, v: Vector): Vector = {
+    val n = v.size
+    v match {
+      case sv: SparseVector =>
+        val nnz = sv.indices.size
+        val newValues = new Array[Double](nnz)
+        var k = 0
+        while (k < nnz) {
+          newValues(k) = sv.values(k) * idf(sv.indices(k))
+          k += 1
+        }
+        Vectors.sparse(n, sv.indices, newValues)
+      case dv: DenseVector =>
+        val newValues = new Array[Double](n)
+        var j = 0
+        while (j < n) {
+          newValues(j) = dv.values(j) * idf(j)
+          j += 1
+        }
+        Vectors.dense(newValues)
+      case other =>
+        throw new UnsupportedOperationException(
+          s"Only sparse and dense vectors are supported but got ${other.getClass}.")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8098fab0/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
index 30147e7..0a5cad7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.feature
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
@@ -41,18 +40,26 @@ class IDFSuite extends FunSuite with MLlibTestSparkContext {
       math.log((m + 1.0) / (x + 1.0))
     })
     assert(model.idf ~== expected absTol 1e-12)
-    val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
-    assert(tfidf.size === 3)
-    val tfidf0 = tfidf(0L).asInstanceOf[SparseVector]
-    assert(tfidf0.indices === Array(1, 3))
-    assert(Vectors.dense(tfidf0.values) ~==
-      Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12)
-    val tfidf1 = tfidf(1L).asInstanceOf[DenseVector]
-    assert(Vectors.dense(tfidf1.values) ~==
-      Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12)
-    val tfidf2 = tfidf(2L).asInstanceOf[SparseVector]
-    assert(tfidf2.indices === Array(1))
-    assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
+
+    val assertHelper = (tfidf: Array[Vector]) => {
+      assert(tfidf.size === 3)
+      val tfidf0 = tfidf(0).asInstanceOf[SparseVector]
+      assert(tfidf0.indices === Array(1, 3))
+      assert(Vectors.dense(tfidf0.values) ~==
+          Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12)
+      val tfidf1 = tfidf(1).asInstanceOf[DenseVector]
+      assert(Vectors.dense(tfidf1.values) ~==
+          Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12)
+      val tfidf2 = tfidf(2).asInstanceOf[SparseVector]
+      assert(tfidf2.indices === Array(1))
+      assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
+    }
+    // Transforms an RDD
+    val tfidf = model.transform(termFrequencies).collect()
+    assertHelper(tfidf)
+    // Transforms local vectors
+    val localTfidf = localTermFrequencies.map(model.transform(_)).toArray
+    assertHelper(localTfidf)
   }
 
   test("idf minimum document frequency filtering") {
@@ -74,18 +81,26 @@ class IDFSuite extends FunSuite with MLlibTestSparkContext {
       }
     })
     assert(model.idf ~== expected absTol 1e-12)
-    val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
-    assert(tfidf.size === 3)
-    val tfidf0 = tfidf(0L).asInstanceOf[SparseVector]
-    assert(tfidf0.indices === Array(1, 3))
-    assert(Vectors.dense(tfidf0.values) ~==
-      Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12)
-    val tfidf1 = tfidf(1L).asInstanceOf[DenseVector]
-    assert(Vectors.dense(tfidf1.values) ~==
-      Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12)
-    val tfidf2 = tfidf(2L).asInstanceOf[SparseVector]
-    assert(tfidf2.indices === Array(1))
-    assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
+
+    val assertHelper = (tfidf: Array[Vector]) => {
+      assert(tfidf.size === 3)
+      val tfidf0 = tfidf(0).asInstanceOf[SparseVector]
+      assert(tfidf0.indices === Array(1, 3))
+      assert(Vectors.dense(tfidf0.values) ~==
+          Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12)
+      val tfidf1 = tfidf(1).asInstanceOf[DenseVector]
+      assert(Vectors.dense(tfidf1.values) ~==
+          Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12)
+      val tfidf2 = tfidf(2).asInstanceOf[SparseVector]
+      assert(tfidf2.indices === Array(1))
+      assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
+    }
+    // Transforms an RDD
+    val tfidf = model.transform(termFrequencies).collect()
+    assertHelper(tfidf)
+    // Transforms local vectors
+    val localTfidf = localTermFrequencies.map(model.transform(_)).toArray
+    assertHelper(localTfidf)
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8098fab0/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 8cb992d..741c630 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -28,7 +28,7 @@ from py4j.protocol import Py4JJavaError
 
 from pyspark import RDD, SparkContext
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors, _convert_to_vector
+from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector
 
 __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
            'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
@@ -212,7 +212,7 @@ class IDFModel(JavaVectorTransformer):
     """
     Represents an IDF model that can transform term frequency vectors.
     """
-    def transform(self, dataset):
+    def transform(self, x):
         """
         Transforms term frequency (TF) vectors to TF-IDF vectors.
 
@@ -220,12 +220,14 @@ class IDFModel(JavaVectorTransformer):
         the terms which occur in fewer than `minDocFreq`
         documents will have an entry of 0.
 
-        :param dataset: an RDD of term frequency vectors
-        :return: an RDD of TF-IDF vectors
+        :param x: an RDD of term frequency vectors or a term frequency vector
+        :return: an RDD of TF-IDF vectors or a TF-IDF vector
         """
-        if not isinstance(dataset, RDD):
-            raise TypeError("dataset should be an RDD of term frequency vectors")
-        return JavaVectorTransformer.transform(self, dataset)
+        if isinstance(x, RDD):
+            return JavaVectorTransformer.transform(self, x)
+
+        x = _convert_to_vector(x)
+        return JavaVectorTransformer.transform(self, x)
 
 
 class IDF(object):
@@ -255,6 +257,12 @@ class IDF(object):
     SparseVector(4, {1: 0.0, 3: 0.5754})
     DenseVector([0.0, 0.0, 1.3863, 0.863])
     SparseVector(4, {1: 0.0})
+    >>> model.transform(Vectors.dense([0.0, 1.0, 2.0, 3.0]))
+    DenseVector([0.0, 0.0, 1.3863, 0.863])
+    >>> model.transform([0.0, 1.0, 2.0, 3.0])
+    DenseVector([0.0, 0.0, 1.3863, 0.863])
+    >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0)))
+    SparseVector(4, {1: 0.0, 3: 0.5754})
     """
     def __init__(self, minDocFreq=0):
         """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org