You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/11/12 22:56:45 UTC

spark git commit: [SPARK-4369] [MLLib] fix TreeModel.predict() with RDD

Repository: spark
Updated Branches:
  refs/heads/master a5ef58113 -> bd86118c4


[SPARK-4369] [MLLib] fix TreeModel.predict() with RDD

Fix TreeModel.predict() with RDD and add tests for it.

(Also checked that other models don't have this issue.)

Author: Davies Liu <da...@databricks.com>

Closes #3230 from davies/predict and squashes the following commits:

81172aa [Davies Liu] fix predict


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bd86118c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bd86118c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bd86118c

Branch: refs/heads/master
Commit: bd86118c4e980f94916f892c76fb808fd4c8bd85
Parents: a5ef581
Author: Davies Liu <da...@databricks.com>
Authored: Wed Nov 12 13:56:41 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Nov 12 13:56:41 2014 -0800

----------------------------------------------------------------------
 .../mllib/tree/model/DecisionTreeModel.scala    | 12 +++++++++
 python/pyspark/mllib/tree.py                    | 26 +++++++++++---------
 2 files changed, 26 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bd86118c/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ec1d99a..ac4d02e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.tree.model
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.rdd.RDD
@@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
     features.map(x => predict(x))
   }
 
+
+  /**
+   * Predict values for the given data set using the model trained.
+   *
+   * @param features JavaRDD representing data points to be predicted
+   * @return JavaRDD of predictions for each of the given data points
+   */
+  def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
+    predict(features.rdd)
+  }
+
   /**
    * Get number of nodes in tree, including leaf nodes.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/bd86118c/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 5d1a3c0..ef0d556 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -124,10 +124,13 @@ class DecisionTree(object):
            Predict: 0.0
           Else (feature 0 > 0.0)
            Predict: 1.0
-        >>> model.predict(array([1.0])) > 0
-        True
-        >>> model.predict(array([0.0])) == 0
-        True
+        >>> model.predict(array([1.0]))
+        1.0
+        >>> model.predict(array([0.0]))
+        0.0
+        >>> rdd = sc.parallelize([[1.0], [0.0]])
+        >>> model.predict(rdd).collect()
+        [1.0, 0.0]
         """
         return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
                                    impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
@@ -170,14 +173,13 @@ class DecisionTree(object):
         ... ]
         >>>
         >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
-        >>> model.predict(array([0.0, 1.0])) == 1
-        True
-        >>> model.predict(array([0.0, 0.0])) == 0
-        True
-        >>> model.predict(SparseVector(2, {1: 1.0})) == 1
-        True
-        >>> model.predict(SparseVector(2, {1: 0.0})) == 0
-        True
+        >>> model.predict(SparseVector(2, {1: 1.0}))
+        1.0
+        >>> model.predict(SparseVector(2, {1: 0.0}))
+        0.0
+        >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
+        >>> model.predict(rdd).collect()
+        [1.0, 0.0]
         """
         return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
                                    impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org