Posted to commits@spark.apache.org by yl...@apache.org on 2017/09/14 06:09:53 UTC

spark git commit: [SPARK-18608][ML][FOLLOWUP] Fix double caching for PySpark OneVsRest.

Repository: spark
Updated Branches:
  refs/heads/master 66cb72d7b -> c76153cc7


[SPARK-18608][ML][FOLLOWUP] Fix double caching for PySpark OneVsRest.

## What changes were proposed in this pull request?
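#19197 fixed double caching for MLlib algorithms but missed PySpark `OneVsRest`; this PR fixes it. `OneVsRest` and `OneVsRestModel` previously decided whether to persist by calling `dataset.rdd.getStorageLevel()`, which reports the storage level of a newly derived RDD rather than of the input DataFrame. As a result, the data was cached a second time even when the input DataFrame was already persisted. Both now check `dataset.storageLevel` instead.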
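A toy model in plain Python (no Spark required) can illustrate how the two persistence checks differ. `Level`, `ToyRDD`, `ToyDataFrame`, `needs_persist_before` and `needs_persist_after` are hypothetical stand-ins, not PySpark APIs. The sketch assumes that each access to a DataFrame's `rdd` derives a fresh RDD whose storage level stays at "none" even when the DataFrame itself is cached.

```python
from dataclasses import dataclass


# Hypothetical stand-in modeled on pyspark's StorageLevel fields
# (useDisk, useMemory, useOffHeap, deserialized, replication).
# A frozen dataclass compares by value, so two "none" levels are equal.
@dataclass(frozen=True)
class Level:
    use_disk: bool
    use_memory: bool
    use_off_heap: bool
    deserialized: bool
    replication: int = 1


NONE = Level(False, False, False, False)
MEMORY_AND_DISK = Level(True, True, False, False)


class ToyRDD:
    """Hypothetical RDD that starts out unpersisted."""

    def __init__(self):
        self.level = NONE

    def get_storage_level(self):
        return self.level


class ToyDataFrame:
    """Hypothetical DataFrame whose cache state is tracked on the DataFrame."""

    def __init__(self):
        self.storage_level = NONE

    def persist(self, level):
        self.storage_level = level
        return self

    @property
    def rdd(self):
        # Assumption: every access derives a brand-new, unpersisted RDD,
        # regardless of whether this DataFrame has been cached.
        return ToyRDD()


def needs_persist_before(df):
    # Old-style check: looks at the derived RDD, which is always NONE here.
    return df.rdd.get_storage_level() == NONE


def needs_persist_after(df):
    # New-style check: looks at the DataFrame's own storage level.
    return df.storage_level == NONE


cached = ToyDataFrame().persist(MEMORY_AND_DISK)
uncached = ToyDataFrame()

print(needs_persist_before(cached))   # True: would cache the data again
print(needs_persist_after(cached))    # False: already cached, skip persist
print(needs_persist_after(uncached))  # True: not cached, so persist it
```

In this model the old check cannot tell a cached input from an uncached one, which is the double-caching behavior the PR describes; the new check only triggers persistence when the input really is unpersisted.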

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <yb...@gmail.com>

Closes #19220 from yanboliang/SPARK-18608.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c76153cc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c76153cc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c76153cc

Branch: refs/heads/master
Commit: c76153cc7dd25b8de5266fe119095066be7f78f5
Parents: 66cb72d
Author: Yanbo Liang <yb...@gmail.com>
Authored: Thu Sep 14 14:09:44 2017 +0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Thu Sep 14 14:09:44 2017 +0800

----------------------------------------------------------------------
 python/pyspark/ml/classification.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c76153cc/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 0caafa6..27ad1e8 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1773,8 +1773,7 @@ class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, Java
             multiclassLabeled = dataset.select(labelCol, featuresCol)
 
         # persist if underlying dataset is not persistent.
-        handlePersistence = \
-            dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False)
+        handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
         if handlePersistence:
             multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
 
@@ -1928,8 +1927,7 @@ class OneVsRestModel(Model, OneVsRestParams, JavaMLReadable, JavaMLWritable):
         newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))
 
         # persist if underlying dataset is not persistent.
-        handlePersistence = \
-            dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False)
+        handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
         if handlePersistence:
             newDataset.persist(StorageLevel.MEMORY_AND_DISK)
 

