You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2017/09/22 05:12:41 UTC

spark git commit: [SPARK-21981][PYTHON][ML] Added Python interface for ClusteringEvaluator

Repository: spark
Updated Branches:
  refs/heads/master fedf6961b -> 5ac96854c


[SPARK-21981][PYTHON][ML] Added Python interface for ClusteringEvaluator

## What changes were proposed in this pull request?

Added Python interface for ClusteringEvaluator

## How was this patch tested?

Manual test, e.g. the example Python code in the class docstring.

cc yanboliang

Author: Marco Gaido <mg...@hortonworks.com>
Author: Marco Gaido <ma...@gmail.com>

Closes #19204 from mgaido91/SPARK-21981.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ac96854
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ac96854
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ac96854

Branch: refs/heads/master
Commit: 5ac96854cc6186fa2dad602d0906ff2705e3f610
Parents: fedf696
Author: Marco Gaido <mg...@hortonworks.com>
Authored: Fri Sep 22 13:12:33 2017 +0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Fri Sep 22 13:12:33 2017 +0800

----------------------------------------------------------------------
 python/pyspark/ml/evaluation.py | 76 +++++++++++++++++++++++++++++++++++-
 1 file changed, 74 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5ac96854/python/pyspark/ml/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 09cdf9b..aa8dbe7 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -20,12 +20,13 @@ from abc import abstractmethod, ABCMeta
 from pyspark import since, keyword_only
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params, TypeConverters
-from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
+from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol, \
+    HasFeaturesCol
 from pyspark.ml.common import inherit_doc
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 
 __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
-           'MulticlassClassificationEvaluator']
+           'MulticlassClassificationEvaluator', 'ClusteringEvaluator']
 
 
 @inherit_doc
@@ -325,6 +326,77 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
         kwargs = self._input_kwargs
         return self._set(**kwargs)
 
+
+@inherit_doc
+class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
+                          JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Evaluator for clustering results, which expects two input columns:
+    prediction and features. :py:attr:`metricName` defaults to "silhouette".
+
+    >>> from pyspark.ml.linalg import Vectors
+    >>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
+    ...     [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0),
+    ...     ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
+    >>> dataset = spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
+    ...
+    >>> evaluator = ClusteringEvaluator(predictionCol="prediction")
+    >>> evaluator.evaluate(dataset)
+    0.9079...
+    >>> ce_path = temp_path + "/ce"
+    >>> evaluator.save(ce_path)
+    >>> evaluator2 = ClusteringEvaluator.load(ce_path)
+    >>> str(evaluator2.getPredictionCol())
+    'prediction'
+
+    .. versionadded:: 2.3.0
+    """
+    metricName = Param(Params._dummy(), "metricName",
+                       "metric name in evaluation (silhouette)",
+                       typeConverter=TypeConverters.toString)
+
+    @keyword_only
+    def __init__(self, predictionCol="prediction", featuresCol="features",
+                 metricName="silhouette"):
+        """
+        __init__(self, predictionCol="prediction", featuresCol="features", \
+                 metricName="silhouette")
+        """
+        super(ClusteringEvaluator, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid)
+        self._setDefault(metricName="silhouette")
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    @since("2.3.0")
+    def setMetricName(self, value):
+        """
+        Sets the value of :py:attr:`metricName`.
+        """
+        return self._set(metricName=value)
+
+    @since("2.3.0")
+    def getMetricName(self):
+        """
+        Gets the value of :py:attr:`metricName` or its default value.
+        """
+        return self.getOrDefault(self.metricName)
+
+    @keyword_only
+    @since("2.3.0")
+    def setParams(self, predictionCol="prediction", featuresCol="features",
+                  metricName="silhouette"):
+        """
+        setParams(self, predictionCol="prediction", featuresCol="features", \
+                  metricName="silhouette")
+        Sets params for this clustering evaluator.
+        """
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
 if __name__ == "__main__":
     import doctest
     import tempfile


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org