You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2017/09/08 19:09:04 UTC

spark git commit: [SPARK-19866][ML][PYSPARK] Add local version of Word2Vec findSynonyms for spark.ml: Python API

Repository: spark
Updated Branches:
  refs/heads/master 8598d03a0 -> 31c74fec2


[SPARK-19866][ML][PYSPARK] Add local version of Word2Vec findSynonyms for spark.ml: Python API

https://issues.apache.org/jira/browse/SPARK-19866

## What changes were proposed in this pull request?

Add a Python API for `findSynonymsArray`, matching the existing Scala API.

## How was this patch tested?

Manually tested by running the PySpark ML test module:
`./python/run-tests --python-executables=python2.7 --modules=pyspark-ml`

Author: Xin Ren <ia...@126.com>
Author: Xin Ren <re...@gmail.com>
Author: Xin Ren <ke...@users.noreply.github.com>

Closes #17451 from keypointt/SPARK-19866.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/31c74fec
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/31c74fec
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/31c74fec

Branch: refs/heads/master
Commit: 31c74fec24ae3bc8b9eb4ecd90896de459c3cc22
Parents: 8598d03
Author: Xin Ren <ia...@126.com>
Authored: Fri Sep 8 12:09:00 2017 -0700
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Fri Sep 8 12:09:00 2017 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/feature/Word2Vec.scala |  2 +-
 python/pyspark/ml/feature.py                         | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/31c74fec/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index d4c8e4b..f6095e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -229,7 +229,7 @@ class Word2VecModel private[ml] (
    * Find "num" number of words closest in similarity to the given word, not
    * including the word itself.
    * @return a dataframe with columns "word" and "similarity" of the word and the cosine
-   * similarities between the synonyms and the given word vector.
+   * similarities between the synonyms and the given word.
    */
   @Since("1.5.0")
   def findSynonyms(word: String, num: Int): DataFrame = {

http://git-wip-us.apache.org/repos/asf/spark/blob/31c74fec/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 050537b..232ae3e 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2751,6 +2751,8 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
     |   c|[-0.3794820010662...|
     +----+--------------------+
     ...
+    >>> model.findSynonymsArray("a", 2)
+    [(u'b', 0.25053444504737854), (u'c', -0.6980510950088501)]
     >>> from pyspark.sql.functions import format_number as fmt
     >>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show()
     +----+----------+
@@ -2927,6 +2929,19 @@ class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable):
             word = _convert_to_vector(word)
         return self._call_java("findSynonyms", word, num)
 
+    @since("2.3.0")
+    def findSynonymsArray(self, word, num):
+        """
+        Find "num" number of words closest in similarity to "word".
+        word can be a string or vector representation.
+        Returns an array with two fields word and similarity (which
+        gives the cosine similarity).
+        """
+        if not isinstance(word, basestring):
+            word = _convert_to_vector(word)
+        tuples = self._java_obj.findSynonymsArray(word, num)
+        return list(map(lambda st: (st._1(), st._2()), list(tuples)))
+
 
 @inherit_doc
 class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org