Posted to commits@spark.apache.org by me...@apache.org on 2015/05/06 10:28:53 UTC

spark git commit: [SPARK-6940] [MLLIB] Add CrossValidator to Python ML pipeline API

Repository: spark
Updated Branches:
  refs/heads/master 9f019c722 -> 32cdc815c


[SPARK-6940] [MLLIB] Add CrossValidator to Python ML pipeline API

Since CrossValidator is a meta-algorithm, we port the implementation to Python rather than wrapping the Scala one. jkbradley
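
For reference, a minimal usage sketch of the new API, adapted from the doctest added in python/pyspark/ml/tuning.py below (it assumes a live SQLContext bound to the name sqlContext):

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
    from pyspark.mllib.linalg import Vectors

    # Small toy dataset, repeated so every fold has data in both classes.
    dataset = sqlContext.createDataFrame(
        [(Vectors.dense([0.0, 1.0]), 0.0),
         (Vectors.dense([1.0, 2.0]), 1.0),
         (Vectors.dense([0.55, 3.0]), 0.0),
         (Vectors.dense([0.45, 4.0]), 1.0),
         (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
        ["features", "label"])

    lr = LogisticRegression()
    # Grid of three candidate settings for maxIter; CrossValidator keeps
    # the setting with the best cross-validated evaluator metric.
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
    evaluator = BinaryClassificationEvaluator()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=evaluator, numFolds=3)
    cvModel = cv.fit(dataset)                 # returns a CrossValidatorModel
    predictions = cvModel.transform(dataset)  # delegates to the best model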

Author: Xiangrui Meng <me...@databricks.com>

Closes #5926 from mengxr/SPARK-6940 and squashes the following commits:

6af181f [Xiangrui Meng] add TODOs
8285134 [Xiangrui Meng] update doc
060f7c3 [Xiangrui Meng] update doctest
acac727 [Xiangrui Meng] add keyword args
cdddecd [Xiangrui Meng] add CrossValidator in Python


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32cdc815
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32cdc815
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32cdc815

Branch: refs/heads/master
Commit: 32cdc815c6fc19b5c8c4eca35f88a61302d67cd5
Parents: 9f019c7
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed May 6 01:28:43 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed May 6 01:28:43 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/tuning/CrossValidator.scala |   7 +-
 python/pyspark/ml/pipeline.py                   |  13 +-
 python/pyspark/ml/tuning.py                     | 183 ++++++++++++++++++-
 python/pyspark/ml/wrapper.py                    |   4 +-
 4 files changed, 199 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/32cdc815/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index cee2aa6..9208127 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -52,10 +52,12 @@ private[ml] trait CrossValidatorParams extends Params {
   def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
 
   /**
-   * param for the evaluator for selection
+   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
+   * metric
    * @group param
    */
-  val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
@@ -120,6 +122,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
       trainingDataset.unpersist()
       var i = 0
       while (i < numModels) {
+        // TODO: duplicate evaluator to take extra params from input
         val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
         logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
         metrics(i) += metric

http://git-wip-us.apache.org/repos/asf/spark/blob/32cdc815/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 7b875e4..c1b2077 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -22,7 +22,7 @@ from pyspark.ml.util import keyword_only
 from pyspark.mllib.common import inherit_doc
 
 
-__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator', 'Model']
 
 
 @inherit_doc
@@ -71,6 +71,15 @@ class Transformer(Params):
 
 
 @inherit_doc
+class Model(Transformer):
+    """
+    Abstract class for models that are fitted by estimators.
+    """
+
+    __metaclass__ = ABCMeta
+
+
+@inherit_doc
 class Pipeline(Estimator):
     """
     A simple pipeline, which acts as an estimator. A Pipeline consists
@@ -154,7 +163,7 @@ class Pipeline(Estimator):
 
 
 @inherit_doc
-class PipelineModel(Transformer):
+class PipelineModel(Model):
     """
     Represents a compiled pipeline with transformers and fitted models.
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/32cdc815/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 1773ab5..f6cf2c3 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -16,8 +16,14 @@
 #
 
 import itertools
+import numpy as np
 
-__all__ = ['ParamGridBuilder']
+from pyspark.ml.param import Params, Param
+from pyspark.ml import Estimator, Model
+from pyspark.ml.util import keyword_only
+from pyspark.sql.functions import rand
+
+__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
 
 
 class ParamGridBuilder(object):
@@ -79,6 +85,179 @@ class ParamGridBuilder(object):
         return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
+class CrossValidator(Estimator):
+    """
+    K-fold cross validation.
+
+    >>> from pyspark.ml.classification import LogisticRegression
+    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> dataset = sqlContext.createDataFrame(
+    ...     [(Vectors.dense([0.0, 1.0]), 0.0),
+    ...      (Vectors.dense([1.0, 2.0]), 1.0),
+    ...      (Vectors.dense([0.55, 3.0]), 0.0),
+    ...      (Vectors.dense([0.45, 4.0]), 1.0),
+    ...      (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+    ...     ["features", "label"])
+    >>> lr = LogisticRegression()
+    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+    >>> evaluator = BinaryClassificationEvaluator()
+    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+    >>> cvModel = cv.fit(dataset)
+    >>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
+    >>> cvModel.transform(dataset).collect() == expected.collect()
+    True
+    """
+
+    # a placeholder to make it appear in the generated doc
+    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
+
+    # a placeholder to make it appear in the generated doc
+    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
+
+    # a placeholder to make it appear in the generated doc
+    evaluator = Param(
+        Params._dummy(), "evaluator",
+        "evaluator used to select hyper-parameters that maximize the cross-validated metric")
+
+    # a placeholder to make it appear in the generated doc
+    numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
+
+    @keyword_only
+    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        """
+        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
+        """
+        super(CrossValidator, self).__init__()
+        #: param for estimator to be cross-validated
+        self.estimator = Param(self, "estimator", "estimator to be cross-validated")
+        #: param for estimator param maps
+        self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
+        #: param for the evaluator used to select hyper-parameters that
+        #: maximize the cross-validated metric
+        self.evaluator = Param(
+            self, "evaluator",
+            "evaluator used to select hyper-parameters that maximize the cross-validated metric")
+        #: param for number of folds for cross validation
+        self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
+        self._setDefault(numFolds=3)
+        kwargs = self.__init__._input_kwargs
+        self._set(**kwargs)
+
+    @keyword_only
+    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        """
+        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        Sets params for cross validator.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setEstimator(self, value):
+        """
+        Sets the value of :py:attr:`estimator`.
+        """
+        self.paramMap[self.estimator] = value
+        return self
+
+    def getEstimator(self):
+        """
+        Gets the value of estimator or its default value.
+        """
+        return self.getOrDefault(self.estimator)
+
+    def setEstimatorParamMaps(self, value):
+        """
+        Sets the value of :py:attr:`estimatorParamMaps`.
+        """
+        self.paramMap[self.estimatorParamMaps] = value
+        return self
+
+    def getEstimatorParamMaps(self):
+        """
+        Gets the value of estimatorParamMaps or its default value.
+        """
+        return self.getOrDefault(self.estimatorParamMaps)
+
+    def setEvaluator(self, value):
+        """
+        Sets the value of :py:attr:`evaluator`.
+        """
+        self.paramMap[self.evaluator] = value
+        return self
+
+    def getEvaluator(self):
+        """
+        Gets the value of evaluator or its default value.
+        """
+        return self.getOrDefault(self.evaluator)
+
+    def setNumFolds(self, value):
+        """
+        Sets the value of :py:attr:`numFolds`.
+        """
+        self.paramMap[self.numFolds] = value
+        return self
+
+    def getNumFolds(self):
+        """
+        Gets the value of numFolds or its default value.
+        """
+        return self.getOrDefault(self.numFolds)
+
+    def fit(self, dataset, params={}):
+        paramMap = self.extractParamMap(params)
+        est = paramMap[self.estimator]
+        epm = paramMap[self.estimatorParamMaps]
+        numModels = len(epm)
+        eva = paramMap[self.evaluator]
+        nFolds = paramMap[self.numFolds]
+        h = 1.0 / nFolds
+        randCol = self.uid + "_rand"
+        df = dataset.select("*", rand(0).alias(randCol))
+        metrics = np.zeros(numModels)
+        for i in range(nFolds):
+            validateLB = i * h
+            validateUB = (i + 1) * h
+            condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
+            validation = df.filter(condition)
+            train = df.filter(~condition)
+            for j in range(numModels):
+                model = est.fit(train, epm[j])
+                # TODO: duplicate evaluator to take extra params from input
+                metric = eva.evaluate(model.transform(validation, epm[j]))
+                metrics[j] += metric
+        bestIndex = np.argmax(metrics)
+        bestModel = est.fit(dataset, epm[bestIndex])
+        return CrossValidatorModel(bestModel)
+
+
+class CrossValidatorModel(Model):
+    """
+    Model from k-fold cross validation.
+    """
+
+    def __init__(self, bestModel):
+        #: best model from cross validation
+        self.bestModel = bestModel
+
+    def transform(self, dataset, params={}):
+        return self.bestModel.transform(dataset, params)
+
+
 if __name__ == "__main__":
     import doctest
-    doctest.testmod()
+    from pyspark.context import SparkContext
+    from pyspark.sql import SQLContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    sc = SparkContext("local[2]", "ml.tuning tests")
+    sqlContext = SQLContext(sc)
+    globs['sc'] = sc
+    globs['sqlContext'] = sqlContext
+    (failure_count, test_count) = doctest.testmod(
+        globs=globs, optionflags=doctest.ELLIPSIS)
+    sc.stop()
+    if failure_count:
+        exit(-1)
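
One implementation detail of CrossValidator.fit worth noting: instead of materializing numFolds separate splits, it appends a column of uniform random numbers (rand(0)) and cuts [0, 1) into numFolds equal intervals; fold i validates on rows whose random value lands in [i*h, (i+1)*h) with h = 1.0/numFolds, and trains on the complement. A self-contained sketch of the same scheme in plain Python (illustrative only; the helper name is mine, not part of the commit):

    import random

    def kfold_splits(n_rows, n_folds, seed=0):
        """Mimic CrossValidator.fit's fold assignment: one uniform
        random draw per row, bucketed into n_folds equal intervals."""
        rng = random.Random(seed)
        rand_col = [rng.random() for _ in range(n_rows)]
        h = 1.0 / n_folds
        for i in range(n_folds):
            lb, ub = i * h, (i + 1) * h
            validation = [r for r, x in enumerate(rand_col) if lb <= x < ub]
            train = [r for r, x in enumerate(rand_col) if not (lb <= x < ub)]
            yield train, validation

    # Folds are disjoint and together cover every row exactly once; their
    # sizes are only n_rows/n_folds in expectation, since assignment is random.
    for train, validation in kfold_splits(50, 3):
        print(len(train), len(validation))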

http://git-wip-us.apache.org/repos/asf/spark/blob/32cdc815/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 73741c4..0634254 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -20,7 +20,7 @@ from abc import ABCMeta
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
+from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
 from pyspark.mllib.common import inherit_doc
 
 
@@ -133,7 +133,7 @@ class JavaTransformer(Transformer, JavaWrapper):
 
 
 @inherit_doc
-class JavaModel(JavaTransformer):
+class JavaModel(Model, JavaTransformer):
     """
     Base class for :py:class:`Model`s that wrap Java/Scala
     implementations.

