You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/10/14 11:17:10 UTC

spark git commit: [SPARK-15402][ML][PYSPARK] PySpark ml.evaluation should support save/load

Repository: spark
Updated Branches:
  refs/heads/master 2fb12b0a3 -> 1db8feab8


[SPARK-15402][ML][PYSPARK] PySpark ml.evaluation should support save/load

## What changes were proposed in this pull request?
Since ```ml.evaluation``` already supports save/load on the Scala side, supporting it on the Python side is very straightforward and easy.

## How was this patch tested?
Add Python doctests.

Author: Yanbo Liang <yb...@gmail.com>

Closes #13194 from yanboliang/spark-15402.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1db8feab
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1db8feab
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1db8feab

Branch: refs/heads/master
Commit: 1db8feab8c564053c05e8bdc1a7f5026fd637d4f
Parents: 2fb12b0
Author: Yanbo Liang <yb...@gmail.com>
Authored: Fri Oct 14 04:17:03 2016 -0700
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Fri Oct 14 04:17:03 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/evaluation.py | 45 ++++++++++++++++++++++++++++--------
 1 file changed, 36 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1db8feab/python/pyspark/ml/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 1fe8772..7aa16fa 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -22,6 +22,7 @@ from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
 from pyspark.ml.common import inherit_doc
+from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 
 __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
            'MulticlassClassificationEvaluator']
@@ -103,7 +104,8 @@ class JavaEvaluator(JavaParams, Evaluator):
 
 
 @inherit_doc
-class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
+class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol,
+                                    JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
 
@@ -121,6 +123,11 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
     0.70...
     >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
     0.83...
+    >>> bce_path = temp_path + "/bce"
+    >>> evaluator.save(bce_path)
+    >>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
+    >>> str(evaluator2.getRawPredictionCol())
+    'raw'
 
     .. versionadded:: 1.4.0
     """
@@ -172,7 +179,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
 
 
 @inherit_doc
-class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
+class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
+                          JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
 
@@ -190,6 +198,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
     0.993...
     >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
     2.649...
+    >>> re_path = temp_path + "/re"
+    >>> evaluator.save(re_path)
+    >>> evaluator2 = RegressionEvaluator.load(re_path)
+    >>> str(evaluator2.getPredictionCol())
+    'raw'
 
     .. versionadded:: 1.4.0
     """
@@ -244,7 +257,8 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
 
 
 @inherit_doc
-class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
+class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
+                                        JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
 
@@ -260,6 +274,11 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     0.66...
     >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
     0.66...
+    >>> mce_path = temp_path + "/mce"
+    >>> evaluator.save(mce_path)
+    >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
+    >>> str(evaluator2.getPredictionCol())
+    'prediction'
 
     .. versionadded:: 1.5.0
     """
@@ -311,19 +330,27 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
 
 if __name__ == "__main__":
     import doctest
+    import tempfile
+    import pyspark.ml.evaluation
     from pyspark.sql import SparkSession
-    globs = globals().copy()
+    globs = pyspark.ml.evaluation.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     spark = SparkSession.builder\
         .master("local[2]")\
         .appName("ml.evaluation tests")\
         .getOrCreate()
-    sc = spark.sparkContext
-    globs['sc'] = sc
     globs['spark'] = spark
-    (failure_count, test_count) = doctest.testmod(
-        globs=globs, optionflags=doctest.ELLIPSIS)
-    spark.stop()
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        spark.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org