Posted to commits@spark.apache.org by me...@apache.org on 2015/01/29 02:14:26 UTC

spark git commit: [SPARK-4586][MLLIB] Python API for ML pipeline and parameters

Repository: spark
Updated Branches:
  refs/heads/master e023112d3 -> e80dc1c5a


[SPARK-4586][MLLIB] Python API for ML pipeline and parameters

This PR adds a Python API for ML pipelines and parameters. The design doc can be found on the JIRA page. It includes two transformers (Tokenizer and HashingTF) and one estimator (LogisticRegression), which are used to demonstrate the simple text classification example.

TODO:
- [x] handle parameters in LRModel
- [x] unit tests
- [x] missing some docs

CC: davies jkbradley

Author: Xiangrui Meng <me...@databricks.com>
Author: Davies Liu <da...@databricks.com>

Closes #4151 from mengxr/SPARK-4586 and squashes the following commits:

415268e [Xiangrui Meng] remove inherit_doc from __init__
edbd6fe [Xiangrui Meng] move Identifiable to ml.util
44c2405 [Xiangrui Meng] Merge pull request #2 from davies/ml
dd1256b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4586
14ae7e2 [Davies Liu] fix docs
54ca7df [Davies Liu] fix tests
78638df [Davies Liu] Merge branch 'SPARK-4586' of github.com:mengxr/spark into ml
fc59a02 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4586
1dca16a [Davies Liu] refactor
090b3a3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into ml
0882513 [Xiangrui Meng] update doc style
a4f4dbf [Xiangrui Meng] add unit test for LR
7521d1c [Xiangrui Meng] add unit tests to HashingTF and Tokenizer
ba0ba1e [Xiangrui Meng] add unit tests for pipeline
0586c7b [Xiangrui Meng] add more comments to the example
5153cff [Xiangrui Meng] simplify java models
036ca04 [Xiangrui Meng] gen numFeatures
46fa147 [Xiangrui Meng] update mllib/pom.xml to include python files in the assembly
1dcc17e [Xiangrui Meng] update code gen and make param appear in the doc
f66ba0c [Xiangrui Meng] make params a property
d5efd34 [Xiangrui Meng] update doc conf and move embedded param map to instance attribute
f4d0fe6 [Xiangrui Meng] use LabeledDocument and Document in example
05e3e40 [Xiangrui Meng] update example
d3e8dbe [Xiangrui Meng] more docs optimize pipeline.fit impl
56de571 [Xiangrui Meng] fix style
d0c5bb8 [Xiangrui Meng] a working copy
bce72f4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4586
17ecfb9 [Xiangrui Meng] code gen for shared params
d9ea77c [Xiangrui Meng] update doc
c18dca1 [Xiangrui Meng] make the example working
dadd84e [Xiangrui Meng] add base classes and docs
a3015cf [Xiangrui Meng] add Estimator and Transformer
46eea43 [Xiangrui Meng] a pipeline in python
33b68e0 [Xiangrui Meng] a working LR


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e80dc1c5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e80dc1c5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e80dc1c5

Branch: refs/heads/master
Commit: e80dc1c5a80cddba8b367cf5cdf9f71df5d87250
Parents: e023112
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Jan 28 17:14:23 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Jan 28 17:14:23 2015 -0800

----------------------------------------------------------------------
 .../ml/simple_text_classification_pipeline.py   |  79 ++++++
 mllib/pom.xml                                   |   2 +
 .../org/apache/spark/ml/param/params.scala      |   8 +-
 python/docs/conf.py                             |   4 +-
 python/docs/index.rst                           |   1 +
 python/docs/pyspark.ml.rst                      |  29 +++
 python/docs/pyspark.rst                         |   1 +
 python/pyspark/ml/__init__.py                   |  21 ++
 python/pyspark/ml/classification.py             |  76 ++++++
 python/pyspark/ml/feature.py                    |  82 ++++++
 python/pyspark/ml/param/__init__.py             |  82 ++++++
 python/pyspark/ml/param/_gen_shared_params.py   |  98 +++++++
 python/pyspark/ml/param/shared.py               | 260 +++++++++++++++++++
 python/pyspark/ml/pipeline.py                   | 154 +++++++++++
 python/pyspark/ml/tests.py                      | 115 ++++++++
 python/pyspark/ml/util.py                       |  46 ++++
 python/pyspark/ml/wrapper.py                    | 149 +++++++++++
 python/pyspark/sql.py                           |  14 -
 python/run-tests                                |   8 +
 19 files changed, 1212 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/examples/src/main/python/ml/simple_text_classification_pipeline.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py
new file mode 100644
index 0000000..c7df3d7
--- /dev/null
+++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext, Row
+from pyspark.ml import Pipeline
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.ml.classification import LogisticRegression
+
+
+"""
+A simple text classification pipeline that recognizes "spark" from
+input text. This is to show how to create and configure a Spark ML
+pipeline in Python. Run with:
+
+  bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py
+"""
+
+
+if __name__ == "__main__":
+    sc = SparkContext(appName="SimpleTextClassificationPipeline")
+    sqlCtx = SQLContext(sc)
+
+    # Prepare training documents, which are labeled.
+    LabeledDocument = Row('id', 'text', 'label')
+    training = sqlCtx.inferSchema(
+        sc.parallelize([(0L, "a b c d e spark", 1.0),
+                        (1L, "b d", 0.0),
+                        (2L, "spark f g h", 1.0),
+                        (3L, "hadoop mapreduce", 0.0)])
+          .map(lambda x: LabeledDocument(*x)))
+
+    # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+    tokenizer = Tokenizer() \
+        .setInputCol("text") \
+        .setOutputCol("words")
+    hashingTF = HashingTF() \
+        .setInputCol(tokenizer.getOutputCol()) \
+        .setOutputCol("features")
+    lr = LogisticRegression() \
+        .setMaxIter(10) \
+        .setRegParam(0.01)
+    pipeline = Pipeline() \
+        .setStages([tokenizer, hashingTF, lr])
+
+    # Fit the pipeline to training documents.
+    model = pipeline.fit(training)
+
+    # Prepare test documents, which are unlabeled.
+    Document = Row('id', 'text')
+    test = sqlCtx.inferSchema(
+        sc.parallelize([(4L, "spark i j k"),
+                        (5L, "l m n"),
+                        (6L, "mapreduce spark"),
+                        (7L, "apache hadoop")])
+          .map(lambda x: Document(*x)))
+
+    # Make predictions on test documents and print columns of interest.
+    prediction = model.transform(test)
+    prediction.registerTempTable("prediction")
+    selected = sqlCtx.sql("SELECT id, text, prediction from prediction")
+    for row in selected.collect():
+        print row
+
+    sc.stop()
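
Note on the example above: the pipeline chains transformers (Tokenizer, HashingTF) with an estimator (LogisticRegression). The sketch below illustrates the general fit-then-transform idea in plain Python, without Spark. All class and function names in it (LowerCase, VocabularyEstimator, VocabularyModel, fit_pipeline, transform_pipeline) are hypothetical stand-ins for illustration; this is not the implementation in python/pyspark/ml/pipeline.py.

```python
# Hypothetical stand-ins for illustration; these are not the pyspark.ml classes.
class LowerCase(object):
    """A transformer: maps a dataset to a new dataset."""
    def transform(self, dataset):
        return [text.lower() for text in dataset]


class VocabularyEstimator(object):
    """An estimator: learns from a dataset and returns a fitted model."""
    def fit(self, dataset):
        vocab = sorted(set(word for text in dataset for word in text.split()))
        return VocabularyModel(vocab)


class VocabularyModel(object):
    """The fitted model is itself a transformer."""
    def __init__(self, vocab):
        self.vocab = vocab

    def transform(self, dataset):
        return [[w for w in text.split() if w in self.vocab] for text in dataset]


def fit_pipeline(stages, dataset):
    """Fit stages in order, feeding each stage's output to the next stage."""
    fitted = []
    for stage in stages:
        if hasattr(stage, "fit"):
            stage = stage.fit(dataset)  # estimator -> model (a transformer)
        fitted.append(stage)
        dataset = stage.transform(dataset)
    return fitted


def transform_pipeline(fitted, dataset):
    """Apply every fitted stage in order."""
    for stage in fitted:
        dataset = stage.transform(dataset)
    return dataset
```

Calling fit_pipeline([LowerCase(), VocabularyEstimator()], docs) returns a list of transformers only, which is why the fitted result can be applied to unlabeled test documents.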

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/mllib/pom.xml
----------------------------------------------------------------------
diff --git a/mllib/pom.xml b/mllib/pom.xml
index a0bda89..7b7beaf 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -125,6 +125,8 @@
         <directory>../python</directory>
         <includes>
           <include>pyspark/mllib/*.py</include>
+          <include>pyspark/ml/*.py</include>
+          <include>pyspark/ml/param/*.py</include>
         </includes>
       </resource>
     </resources>

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 04f9cfb..5fb4379 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -165,6 +165,13 @@ trait Params extends Identifiable with Serializable {
   }
 
   /**
+   * Sets a parameter (by name) in the embedded param map.
+   */
+  private[ml] def set(param: String, value: Any): this.type = {
+    set(getParam(param), value)
+  }
+
+  /**
    * Gets the value of a parameter in the embedded param map.
    */
   private[ml] def get[T](param: Param[T]): T = {
@@ -286,7 +293,6 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
     new ParamMap(this.map ++ other.map)
   }
 
-
   /**
    * Adds all parameters from the input param map into this param map.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/docs/conf.py
----------------------------------------------------------------------
diff --git a/python/docs/conf.py b/python/docs/conf.py
index e58d97a..b00dce9 100644
--- a/python/docs/conf.py
+++ b/python/docs/conf.py
@@ -55,9 +55,9 @@ copyright = u'2014, Author'
 # built documents.
 #
 # The short X.Y version.
-version = '1.2-SNAPSHOT'
+version = '1.3-SNAPSHOT'
 # The full version, including alpha/beta/rc tags.
-release = '1.2-SNAPSHOT'
+release = '1.3-SNAPSHOT'
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/docs/index.rst
----------------------------------------------------------------------
diff --git a/python/docs/index.rst b/python/docs/index.rst
index 703bef6..d150de9 100644
--- a/python/docs/index.rst
+++ b/python/docs/index.rst
@@ -14,6 +14,7 @@ Contents:
    pyspark
    pyspark.sql
    pyspark.streaming
+   pyspark.ml
    pyspark.mllib
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/docs/pyspark.ml.rst
----------------------------------------------------------------------
diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
new file mode 100644
index 0000000..f10d133
--- /dev/null
+++ b/python/docs/pyspark.ml.rst
@@ -0,0 +1,29 @@
+pyspark.ml package
+=====================
+
+Submodules
+----------
+
+pyspark.ml module
+-----------------
+
+.. automodule:: pyspark.ml
+    :members:
+    :undoc-members:
+    :inherited-members:
+
+pyspark.ml.feature module
+-------------------------
+
+.. automodule:: pyspark.ml.feature
+    :members:
+    :undoc-members:
+    :inherited-members:
+
+pyspark.ml.classification module
+--------------------------------
+
+.. automodule:: pyspark.ml.classification
+    :members:
+    :undoc-members:
+    :inherited-members:

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/docs/pyspark.rst
----------------------------------------------------------------------
diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst
index e81be3b..0df12c4 100644
--- a/python/docs/pyspark.rst
+++ b/python/docs/pyspark.rst
@@ -9,6 +9,7 @@ Subpackages
     
     pyspark.sql
     pyspark.streaming
+    pyspark.ml
     pyspark.mllib
 
 Contents

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
new file mode 100644
index 0000000..47fed80
--- /dev/null
+++ b/python/pyspark/ml/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.param import *
+from pyspark.ml.pipeline import *
+
+__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"]

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
new file mode 100644
index 0000000..6bd2aa8
--- /dev/null
+++ b/python/pyspark/ml/classification.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
+    HasRegParam
+
+
+__all__ = ['LogisticRegression', 'LogisticRegressionModel']
+
+
+@inherit_doc
+class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+                         HasRegParam):
+    """
+    Logistic regression.
+
+    >>> from pyspark.sql import Row
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
+            Row(label=1.0, features=Vectors.dense(1.0)), \
+            Row(label=0.0, features=Vectors.sparse(1, [], []))]))
+    >>> lr = LogisticRegression() \
+            .setMaxIter(5) \
+            .setRegParam(0.01)
+    >>> model = lr.fit(dataset)
+    >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
+    >>> print model.transform(test0).head().prediction
+    0.0
+    >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
+    >>> print model.transform(test1).head().prediction
+    1.0
+    """
+    _java_class = "org.apache.spark.ml.classification.LogisticRegression"
+
+    def _create_model(self, java_model):
+        return LogisticRegressionModel(java_model)
+
+
+class LogisticRegressionModel(JavaModel):
+    """
+    Model fitted by LogisticRegression.
+    """
+
+
+if __name__ == "__main__":
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import SQLContext
+    globs = globals().copy()
+    # Run the doctests against a local SparkContext with two worker threads;
+    # sc and sqlCtx are made available to the doctests through globs.
+    sc = SparkContext("local[2]", "ml.classification tests")
+    sqlCtx = SQLContext(sc)
+    globs['sc'] = sc
+    globs['sqlCtx'] = sqlCtx
+    (failure_count, test_count) = doctest.testmod(
+        globs=globs, optionflags=doctest.ELLIPSIS)
+    sc.stop()
+    if failure_count:
+        exit(-1)

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
new file mode 100644
index 0000000..e088acd
--- /dev/null
+++ b/python/pyspark/ml/feature.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.wrapper import JavaTransformer
+
+__all__ = ['Tokenizer', 'HashingTF']
+
+
+@inherit_doc
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
+    """
+    A tokenizer that converts the input string to lowercase and then
+    splits it by white spaces.
+
+    >>> from pyspark.sql import Row
+    >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")]))
+    >>> tokenizer = Tokenizer() \
+            .setInputCol("text") \
+            .setOutputCol("words")
+    >>> print tokenizer.transform(dataset).head()
+    Row(text=u'a b c', words=[u'a', u'b', u'c'])
+    >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head()
+    Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
+    """
+
+    _java_class = "org.apache.spark.ml.feature.Tokenizer"
+
+
+@inherit_doc
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
+    """
+    Maps a sequence of terms to their term frequencies using the
+    hashing trick.
+
+    >>> from pyspark.sql import Row
+    >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])]))
+    >>> hashingTF = HashingTF() \
+            .setNumFeatures(10) \
+            .setInputCol("words") \
+            .setOutputCol("features")
+    >>> print hashingTF.transform(dataset).head().features
+    (10,[7,8,9],[1.0,1.0,1.0])
+    >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
+    >>> print hashingTF.transform(dataset, params).head().vector
+    (5,[2,3,4],[1.0,1.0,1.0])
+    """
+
+    _java_class = "org.apache.spark.ml.feature.HashingTF"
+
+
+if __name__ == "__main__":
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import SQLContext
+    globs = globals().copy()
+    # Run the doctests against a local SparkContext with two worker threads;
+    # sc and sqlCtx are made available to the doctests through globs.
+    sc = SparkContext("local[2]", "ml.feature tests")
+    sqlCtx = SQLContext(sc)
+    globs['sc'] = sc
+    globs['sqlCtx'] = sqlCtx
+    (failure_count, test_count) = doctest.testmod(
+        globs=globs, optionflags=doctest.ELLIPSIS)
+    sc.stop()
+    if failure_count:
+        exit(-1)
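
Note on the HashingTF docstring above: the "hashing trick" maps each term to a vector index by hashing it and taking the result modulo numFeatures, so no vocabulary has to be stored. The plain-Python sketch below assumes a Java-style string hash. That assumption happens to reproduce the indices shown in the doctest, but the function names (java_string_hash, hashing_tf) are hypothetical and this is not the code path the wrapper actually calls.

```python
def java_string_hash(s):
    """Java-style String.hashCode as a signed 32-bit integer.

    Assumption for illustration; matches Java only for characters in the
    Basic Multilingual Plane.
    """
    h = 0
    for ch in s:
        h = (31 * h + ord(ch)) & 0xFFFFFFFF
    # Reinterpret the unsigned 32-bit value as a signed one.
    return h - (1 << 32) if h >= (1 << 31) else h


def hashing_tf(terms, num_features):
    """Return a sparse term-frequency vector as (size, indices, values)."""
    counts = {}
    for term in terms:
        # Python's % yields a non-negative result for a positive modulus.
        index = java_string_hash(term) % num_features
        counts[index] = counts.get(index, 0.0) + 1.0
    indices = sorted(counts)
    return (num_features, indices, [counts[i] for i in indices])
```

Distinct terms can collide: with 10 features, "a" and "k" both land on index 7, so their counts are added together. Choosing a large numFeatures (the default in this commit is 1 << 18) makes such collisions rarer.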

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/param/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
new file mode 100644
index 0000000..5566792
--- /dev/null
+++ b/python/pyspark/ml/param/__init__.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta
+
+from pyspark.ml.util import Identifiable
+
+
+__all__ = ['Param', 'Params']
+
+
+class Param(object):
+    """
+    A param with self-contained documentation and an optional default value.
+    """
+
+    def __init__(self, parent, name, doc, defaultValue=None):
+        if not isinstance(parent, Identifiable):
+            raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
+        self.parent = parent
+        self.name = str(name)
+        self.doc = str(doc)
+        self.defaultValue = defaultValue
+
+    def __str__(self):
+        return str(self.parent) + "-" + self.name
+
+    def __repr__(self):
+        return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
+               (self.parent, self.name, self.doc, self.defaultValue)
+
+
+class Params(Identifiable):
+    """
+    Components that take parameters. This also provides an internal
+    param map to store parameter values attached to the instance.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(Params, self).__init__()
+        #: embedded param map
+        self.paramMap = {}
+
+    @property
+    def params(self):
+        """
+        Returns all params. The default implementation uses
+        :py:func:`dir` to get all attributes of type
+        :py:class:`Param`.
+        """
+        return filter(lambda attr: isinstance(attr, Param),
+                      [getattr(self, x) for x in dir(self) if x != "params"])
+
+    def _merge_params(self, params):
+        paramMap = self.paramMap.copy()
+        paramMap.update(params)
+        return paramMap
+
+    @staticmethod
+    def _dummy():
+        """
+        Returns a dummy Params instance used as a placeholder to generate docs.
+        """
+        dummy = Params()
+        dummy.uid = "undefined"
+        return dummy
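
Note on Params above: values set on an instance live in its embedded paramMap, and _merge_params overlays a per-call dict on a copy of that map, so the per-call values win without modifying the instance. The self-contained sketch below mirrors that behavior without importing pyspark. Component and its resolve method are hypothetical names added for illustration, and Param here is a simplified copy without the Identifiable check.

```python
class Param(object):
    """Simplified copy of Param, without the Identifiable check."""
    def __init__(self, parent, name, doc, defaultValue=None):
        self.parent = parent
        self.name = name
        self.doc = doc
        self.defaultValue = defaultValue


class Component(object):
    """Hypothetical stand-in for a Params subclass."""
    def __init__(self):
        self.paramMap = {}  # embedded param map, keyed by Param object
        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)

    def _merge_params(self, params):
        # Same logic as Params._merge_params: copy, then overlay.
        paramMap = self.paramMap.copy()
        paramMap.update(params)
        return paramMap

    def resolve(self, param, extra=None):
        """Hypothetical helper: per-call value, else embedded value, else default."""
        merged = self._merge_params(extra or {})
        return merged.get(param, param.defaultValue)
```

This is the same precedence that the doctests rely on when they pass a dict such as {tokenizer.outputCol: "tokens"} to transform.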

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/param/_gen_shared_params.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py
new file mode 100644
index 0000000..5eb8110
--- /dev/null
+++ b/python/pyspark/ml/param/_gen_shared_params.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+header = """#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#"""
+
+
+def _gen_param_code(name, doc, defaultValue):
+    """
+    Generates Python code for a shared param class.
+
+    :param name: param name
+    :param doc: param doc
+    :param defaultValue: string representation of the default value
+    :return: code string
+    """
+    # TODO: How to correctly inherit instance attributes?
+    template = '''class Has$Name(Params):
+    """
+    Params with $name.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    $name = Param(Params._dummy(), "$name", "$doc", $defaultValue)
+
+    def __init__(self):
+        super(Has$Name, self).__init__()
+        #: param for $doc
+        self.$name = Param(self, "$name", "$doc", $defaultValue)
+
+    def set$Name(self, value):
+        """
+        Sets the value of :py:attr:`$name`.
+        """
+        self.paramMap[self.$name] = value
+        return self
+
+    def get$Name(self):
+        """
+        Gets the value of $name or its default value.
+        """
+        if self.$name in self.paramMap:
+            return self.paramMap[self.$name]
+        else:
+            return self.$name.defaultValue'''
+
+    upperCamelName = name[0].upper() + name[1:]
+    return template \
+        .replace("$name", name) \
+        .replace("$Name", upperCamelName) \
+        .replace("$doc", doc) \
+        .replace("$defaultValue", defaultValue)
+
+if __name__ == "__main__":
+    print header
+    print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
+    print "from pyspark.ml.param import Param, Params\n\n"
+    shared = [
+        ("maxIter", "max number of iterations", "100"),
+        ("regParam", "regularization constant", "0.1"),
+        ("featuresCol", "features column name", "'features'"),
+        ("labelCol", "label column name", "'label'"),
+        ("predictionCol", "prediction column name", "'prediction'"),
+        ("inputCol", "input column name", "'input'"),
+        ("outputCol", "output column name", "'output'"),
+        ("numFeatures", "number of features", "1 << 18")]
+    code = []
+    for name, doc, defaultValue in shared:
+        code.append(_gen_param_code(name, doc, defaultValue))
+    print "\n\n\n".join(code)
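
Note on the generator above: it builds each Has* class by plain string substitution on a template and prints the result, which is then saved as shared.py. The sketch below shows the same substitution idea on a much smaller, hypothetical template and loads the generated source with exec instead of writing a file. TEMPLATE, gen_param_code, and the dictionary-based getter are illustrative assumptions, not the generator's actual output.

```python
# A reduced, hypothetical template; the real one also defines Param attributes.
TEMPLATE = '''class Has$Name(object):
    def __init__(self):
        self.paramMap = {}

    def set$Name(self, value):
        self.paramMap["$name"] = value
        return self

    def get$Name(self):
        return self.paramMap.get("$name", $defaultValue)
'''


def gen_param_code(name, default_value):
    """Fill the template; default_value is source text, e.g. "100" or "'label'"."""
    upper_camel = name[0].upper() + name[1:]
    # The replacements are case-sensitive, so "$name" does not touch "$Name".
    return (TEMPLATE
            .replace("$name", name)
            .replace("$Name", upper_camel)
            .replace("$defaultValue", default_value))


# Load the generated class into a namespace instead of printing it to a file.
namespace = {}
exec(gen_param_code("maxIter", "100"), namespace)
HasMaxIter = namespace["HasMaxIter"]
```

Because the default is pasted in as source text, string defaults must carry their own quotes, which is why the shared list above writes "'features'" rather than "features".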

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/param/shared.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
new file mode 100644
index 0000000..586822f
--- /dev/null
+++ b/python/pyspark/ml/param/shared.py
@@ -0,0 +1,260 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# DO NOT MODIFY. The code is generated by _gen_shared_params.py.
+
+from pyspark.ml.param import Param, Params
+
+
+class HasMaxIter(Params):
+    """
+    Params with maxIter.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100)
+
+    def __init__(self):
+        super(HasMaxIter, self).__init__()
+        #: param for max number of iterations
+        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
+
+    def setMaxIter(self, value):
+        """
+        Sets the value of :py:attr:`maxIter`.
+        """
+        self.paramMap[self.maxIter] = value
+        return self
+
+    def getMaxIter(self):
+        """
+        Gets the value of maxIter or its default value.
+        """
+        if self.maxIter in self.paramMap:
+            return self.paramMap[self.maxIter]
+        else:
+            return self.maxIter.defaultValue
+
+
+class HasRegParam(Params):
+    """
+    Params with regParam.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1)
+
+    def __init__(self):
+        super(HasRegParam, self).__init__()
+        #: param for regularization constant
+        self.regParam = Param(self, "regParam", "regularization constant", 0.1)
+
+    def setRegParam(self, value):
+        """
+        Sets the value of :py:attr:`regParam`.
+        """
+        self.paramMap[self.regParam] = value
+        return self
+
+    def getRegParam(self):
+        """
+        Gets the value of regParam or its default value.
+        """
+        if self.regParam in self.paramMap:
+            return self.paramMap[self.regParam]
+        else:
+            return self.regParam.defaultValue
+
+
+class HasFeaturesCol(Params):
+    """
+    Params with featuresCol.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features')
+
+    def __init__(self):
+        super(HasFeaturesCol, self).__init__()
+        #: param for features column name
+        self.featuresCol = Param(self, "featuresCol", "features column name", 'features')
+
+    def setFeaturesCol(self, value):
+        """
+        Sets the value of :py:attr:`featuresCol`.
+        """
+        self.paramMap[self.featuresCol] = value
+        return self
+
+    def getFeaturesCol(self):
+        """
+        Gets the value of featuresCol or its default value.
+        """
+        if self.featuresCol in self.paramMap:
+            return self.paramMap[self.featuresCol]
+        else:
+            return self.featuresCol.defaultValue
+
+
+class HasLabelCol(Params):
+    """
+    Params with labelCol.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label')
+
+    def __init__(self):
+        super(HasLabelCol, self).__init__()
+        #: param for label column name
+        self.labelCol = Param(self, "labelCol", "label column name", 'label')
+
+    def setLabelCol(self, value):
+        """
+        Sets the value of :py:attr:`labelCol`.
+        """
+        self.paramMap[self.labelCol] = value
+        return self
+
+    def getLabelCol(self):
+        """
+        Gets the value of labelCol or its default value.
+        """
+        if self.labelCol in self.paramMap:
+            return self.paramMap[self.labelCol]
+        else:
+            return self.labelCol.defaultValue
+
+
+class HasPredictionCol(Params):
+    """
+    Params with predictionCol.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction')
+
+    def __init__(self):
+        super(HasPredictionCol, self).__init__()
+        #: param for prediction column name
+        self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction')
+
+    def setPredictionCol(self, value):
+        """
+        Sets the value of :py:attr:`predictionCol`.
+        """
+        self.paramMap[self.predictionCol] = value
+        return self
+
+    def getPredictionCol(self):
+        """
+        Gets the value of predictionCol or its default value.
+        """
+        if self.predictionCol in self.paramMap:
+            return self.paramMap[self.predictionCol]
+        else:
+            return self.predictionCol.defaultValue
+
+
+class HasInputCol(Params):
+    """
+    Params with inputCol.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input')
+
+    def __init__(self):
+        super(HasInputCol, self).__init__()
+        #: param for input column name
+        self.inputCol = Param(self, "inputCol", "input column name", 'input')
+
+    def setInputCol(self, value):
+        """
+        Sets the value of :py:attr:`inputCol`.
+        """
+        self.paramMap[self.inputCol] = value
+        return self
+
+    def getInputCol(self):
+        """
+        Gets the value of inputCol or its default value.
+        """
+        if self.inputCol in self.paramMap:
+            return self.paramMap[self.inputCol]
+        else:
+            return self.inputCol.defaultValue
+
+
+class HasOutputCol(Params):
+    """
+    Params with outputCol.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output')
+
+    def __init__(self):
+        super(HasOutputCol, self).__init__()
+        #: param for output column name
+        self.outputCol = Param(self, "outputCol", "output column name", 'output')
+
+    def setOutputCol(self, value):
+        """
+        Sets the value of :py:attr:`outputCol`.
+        """
+        self.paramMap[self.outputCol] = value
+        return self
+
+    def getOutputCol(self):
+        """
+        Gets the value of outputCol or its default value.
+        """
+        if self.outputCol in self.paramMap:
+            return self.paramMap[self.outputCol]
+        else:
+            return self.outputCol.defaultValue
+
+
+class HasNumFeatures(Params):
+    """
+    Params with numFeatures.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18)
+
+    def __init__(self):
+        super(HasNumFeatures, self).__init__()
+        #: param for number of features
+        self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
+
+    def setNumFeatures(self, value):
+        """
+        Sets the value of :py:attr:`numFeatures`.
+        """
+        self.paramMap[self.numFeatures] = value
+        return self
+
+    def getNumFeatures(self):
+        """
+        Gets the value of numFeatures or its default value.
+        """
+        if self.numFeatures in self.paramMap:
+            return self.paramMap[self.numFeatures]
+        else:
+            return self.numFeatures.defaultValue
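The getter/setter pattern that each mixin above repeats can be sketched outside Spark roughly as follows. `ToyParam`, `ToyParams`, and `HasThreshold` are hypothetical stand-ins, not the `pyspark.ml.param` classes; the sketch assumes only that a param carries a default value and that explicitly set values live in a per-instance map.

```python
class ToyParam(object):
    """Hypothetical stand-in for a param: a name, a doc string, and a default."""

    def __init__(self, parent, name, doc, defaultValue=None):
        self.parent = parent
        self.name = name
        self.doc = doc
        self.defaultValue = defaultValue


class ToyParams(object):
    """Hypothetical base class that stores explicitly set values."""

    def __init__(self):
        self.paramMap = {}


class HasThreshold(ToyParams):
    """Hypothetical mixin following the same shape as the generated ones."""

    def __init__(self):
        super(HasThreshold, self).__init__()
        self.threshold = ToyParam(self, "threshold", "decision threshold", 0.5)

    def setThreshold(self, value):
        self.paramMap[self.threshold] = value
        return self  # returning self allows chained setter calls

    def getThreshold(self):
        # dict.get collapses the if/else used by the generated getters
        return self.paramMap.get(self.threshold, self.threshold.defaultValue)


stage = HasThreshold()
default_value = stage.getThreshold()
explicit_value = stage.setThreshold(0.8).getThreshold()
```

Here the getter falls back to the default until a value is set explicitly, which is the behaviour the generated `get*` methods implement with an `if`/`else`.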

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
new file mode 100644
index 0000000..2d239f8
--- /dev/null
+++ b/python/pyspark/ml/pipeline.py
@@ -0,0 +1,154 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import inherit_doc
+
+
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
+
+
+@inherit_doc
+class Estimator(Params):
+    """
+    Abstract class for estimators that fit models to data.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def fit(self, dataset, params={}):
+        """
+        Fits a model to the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.DataFrame`
+        :param params: an optional param map that overwrites embedded
+                       params
+        :returns: fitted model
+        """
+        raise NotImplementedError()
+
+
+@inherit_doc
+class Transformer(Params):
+    """
+    Abstract class for transformers that transform one dataset into
+    another.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def transform(self, dataset, params={}):
+        """
+        Transforms the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.DataFrame`
+        :param params: an optional param map that overwrites embedded
+                       params
+        :returns: transformed dataset
+        """
+        raise NotImplementedError()
+
+
+@inherit_doc
+class Pipeline(Estimator):
+    """
+    A simple pipeline, which acts as an estimator. A Pipeline consists
+    of a sequence of stages, each of which is either an
+    :py:class:`Estimator` or a :py:class:`Transformer`. When
+    :py:meth:`Pipeline.fit` is called, the stages are executed in
+    order. If a stage is an :py:class:`Estimator`, its
+    :py:meth:`Estimator.fit` method will be called on the input
+    dataset to fit a model. Then the model, which is a transformer,
+    will be used to transform the dataset as the input to the next
+    stage. If a stage is a :py:class:`Transformer`, its
+    :py:meth:`Transformer.transform` method will be called to produce
+    the dataset for the next stage. The fitted model from a
+    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
+    consists of fitted models and transformers, corresponding to the
+    pipeline stages. If there are no stages, the pipeline acts as an
+    identity transformer.
+    """
+
+    def __init__(self):
+        super(Pipeline, self).__init__()
+        #: Param for pipeline stages.
+        self.stages = Param(self, "stages", "pipeline stages")
+
+    def setStages(self, value):
+        """
+        Set pipeline stages.
+        :param value: a list of transformers or estimators
+        :return: the pipeline instance
+        """
+        self.paramMap[self.stages] = value
+        return self
+
+    def getStages(self):
+        """
+        Get pipeline stages.
+        """
+        if self.stages in self.paramMap:
+            return self.paramMap[self.stages]
+
+    def fit(self, dataset, params={}):
+        paramMap = self._merge_params(params)
+        stages = paramMap[self.stages]
+        for stage in stages:
+            if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
+                raise ValueError(
+                    "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
+        indexOfLastEstimator = -1
+        for i, stage in enumerate(stages):
+            if isinstance(stage, Estimator):
+                indexOfLastEstimator = i
+        transformers = []
+        for i, stage in enumerate(stages):
+            if i <= indexOfLastEstimator:
+                if isinstance(stage, Transformer):
+                    transformers.append(stage)
+                    dataset = stage.transform(dataset, paramMap)
+                else:  # must be an Estimator
+                    model = stage.fit(dataset, paramMap)
+                    transformers.append(model)
+                    if i < indexOfLastEstimator:
+                        dataset = model.transform(dataset, paramMap)
+            else:
+                transformers.append(stage)
+        return PipelineModel(transformers)
+
+
+@inherit_doc
+class PipelineModel(Transformer):
+    """
+    Represents a compiled pipeline with transformers and fitted models.
+    """
+
+    def __init__(self, transformers):
+        super(PipelineModel, self).__init__()
+        self.transformers = transformers
+
+    def transform(self, dataset, params={}):
+        paramMap = self._merge_params(params)
+        for t in self.transformers:
+            dataset = t.transform(dataset, paramMap)
+        return dataset
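The stage-ordering logic of `Pipeline.fit` above can be illustrated without Spark. This is a minimal sketch under stated assumptions: `ToyTransformer`, `ToyEstimator`, and `fit_pipeline` are hypothetical names, the dataset is a plain list, and param maps are left out. It shows that the model produced by the last estimator is not applied during fitting and that stages after the last estimator are passed through untouched.

```python
class ToyTransformer(object):
    """Hypothetical transformer that appends its name to a list dataset."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def transform(self, dataset):
        self.log.append("transform:" + self.name)
        return dataset + [self.name]


class ToyEstimator(object):
    """Hypothetical estimator whose fitted model is a ToyTransformer."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def fit(self, dataset):
        self.log.append("fit:" + self.name)
        return ToyTransformer(self.name + "_model", self.log)


def fit_pipeline(stages, dataset):
    # Find the last estimator; nothing after it needs data during fitting.
    last_estimator = -1
    for i, stage in enumerate(stages):
        if isinstance(stage, ToyEstimator):
            last_estimator = i
    fitted = []
    for i, stage in enumerate(stages):
        if i > last_estimator:
            fitted.append(stage)  # trailing transformers are kept as-is
        elif isinstance(stage, ToyEstimator):
            model = stage.fit(dataset)
            fitted.append(model)
            if i < last_estimator:
                # Only transform when a later estimator needs the output.
                dataset = model.transform(dataset)
        else:
            fitted.append(stage)
            dataset = stage.transform(dataset)
    return fitted


log = []
stages = [ToyEstimator("e0", log), ToyTransformer("t1", log),
          ToyEstimator("e2", log), ToyTransformer("t3", log)]
fitted = fit_pipeline(stages, [])
```

With these four stages the recorded calls are `fit:e0`, `transform:e0_model`, `transform:t1`, `fit:e2`, which matches the expectations in `PipelineTests.test_pipeline`.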

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
new file mode 100644
index 0000000..b627c2b
--- /dev/null
+++ b/python/pyspark/ml/tests.py
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for Spark ML Python APIs.
+"""
+
+import sys
+
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier\n')
+        sys.exit(1)
+else:
+    import unittest
+
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
+from pyspark.sql import DataFrame
+from pyspark.ml.param import Param
+from pyspark.ml.pipeline import Transformer, Estimator, Pipeline
+
+
+class MockDataset(DataFrame):
+
+    def __init__(self):
+        self.index = 0
+
+
+class MockTransformer(Transformer):
+
+    def __init__(self):
+        super(MockTransformer, self).__init__()
+        self.fake = Param(self, "fake", "fake", None)
+        self.dataset_index = None
+        self.fake_param_value = None
+
+    def transform(self, dataset, params={}):
+        self.dataset_index = dataset.index
+        if self.fake in params:
+            self.fake_param_value = params[self.fake]
+        dataset.index += 1
+        return dataset
+
+
+class MockEstimator(Estimator):
+
+    def __init__(self):
+        super(MockEstimator, self).__init__()
+        self.fake = Param(self, "fake", "fake", None)
+        self.dataset_index = None
+        self.fake_param_value = None
+        self.model = None
+
+    def fit(self, dataset, params={}):
+        self.dataset_index = dataset.index
+        if self.fake in params:
+            self.fake_param_value = params[self.fake]
+        model = MockModel()
+        self.model = model
+        return model
+
+
+class MockModel(MockTransformer, Transformer):
+
+    def __init__(self):
+        super(MockModel, self).__init__()
+
+
+class PipelineTests(PySparkTestCase):
+
+    def test_pipeline(self):
+        dataset = MockDataset()
+        estimator0 = MockEstimator()
+        transformer1 = MockTransformer()
+        estimator2 = MockEstimator()
+        transformer3 = MockTransformer()
+        pipeline = Pipeline() \
+            .setStages([estimator0, transformer1, estimator2, transformer3])
+        pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
+        self.assertEqual(0, estimator0.dataset_index)
+        self.assertEqual(0, estimator0.fake_param_value)
+        model0 = estimator0.model
+        self.assertEqual(0, model0.dataset_index)
+        self.assertEqual(1, transformer1.dataset_index)
+        self.assertEqual(1, transformer1.fake_param_value)
+        self.assertEqual(2, estimator2.dataset_index)
+        model2 = estimator2.model
+        self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should "
+                                                "not be called during fit.")
+        dataset = pipeline_model.transform(dataset)
+        self.assertEqual(2, model0.dataset_index)
+        self.assertEqual(3, transformer1.dataset_index)
+        self.assertEqual(4, model2.dataset_index)
+        self.assertEqual(5, transformer3.dataset_index)
+        self.assertEqual(6, dataset.index)
+
+
+if __name__ == "__main__":
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
new file mode 100644
index 0000000..b1caa84
--- /dev/null
+++ b/python/pyspark/ml/util.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import uuid
+
+
+def inherit_doc(cls):
+    for name, func in vars(cls).items():
+        # only inherit docstring for public functions
+        if name.startswith("_"):
+            continue
+        if not func.__doc__:
+            for parent in cls.__bases__:
+                parent_func = getattr(parent, name, None)
+                if parent_func and getattr(parent_func, "__doc__", None):
+                    func.__doc__ = parent_func.__doc__
+                    break
+    return cls
+
+
+class Identifiable(object):
+    """
+    Object with a unique ID.
+    """
+
+    def __init__(self):
+        #: A unique id for the object. The default implementation
+        #: concatenates the class name, "-", and 8 random hex chars.
+        self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
+
+    def __repr__(self):
+        return self.uid
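A short usage sketch of the `inherit_doc` decorator follows. The decorator body is copied from the diff above so that the block is self-contained; `Base` and `Upper` are hypothetical classes introduced only for illustration.

```python
def inherit_doc(cls):
    # Same logic as the decorator added in pyspark/ml/util.py above.
    for name, func in vars(cls).items():
        # only inherit docstring for public functions
        if name.startswith("_"):
            continue
        if not func.__doc__:
            for parent in cls.__bases__:
                parent_func = getattr(parent, name, None)
                if parent_func and getattr(parent_func, "__doc__", None):
                    func.__doc__ = parent_func.__doc__
                    break
    return cls


class Base(object):
    def transform(self, dataset):
        """Transforms the input dataset."""
        raise NotImplementedError()


@inherit_doc
class Upper(Base):
    # No docstring here; the decorator copies it from Base.transform.
    def transform(self, dataset):
        return [s.upper() for s in dataset]


inherited = Upper.transform.__doc__
```

Only direct base classes are searched (`cls.__bases__`), so a docstring defined further up the hierarchy is picked up only if every intermediate class is decorated as well.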

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
new file mode 100644
index 0000000..9e12ddc
--- /dev/null
+++ b/python/pyspark/ml/wrapper.py
@@ -0,0 +1,149 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta
+
+from pyspark import SparkContext
+from pyspark.sql import DataFrame
+from pyspark.ml.param import Params
+from pyspark.ml.pipeline import Estimator, Transformer
+from pyspark.ml.util import inherit_doc
+
+
+def _jvm():
+    """
+    Returns the JVM view associated with SparkContext. Must be called
+    after SparkContext is initialized.
+    """
+    jvm = SparkContext._jvm
+    if jvm:
+        return jvm
+    else:
+        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
+
+
+@inherit_doc
+class JavaWrapper(Params):
+    """
+    Utility class to help create wrapper classes from Java/Scala
+    implementations of pipeline components.
+    """
+
+    __metaclass__ = ABCMeta
+
+    #: Fully-qualified class name of the wrapped Java component.
+    _java_class = None
+
+    def _java_obj(self):
+        """
+        Returns or creates a Java object.
+        """
+        java_obj = _jvm()
+        for name in self._java_class.split("."):
+            java_obj = getattr(java_obj, name)
+        return java_obj()
+
+    def _transfer_params_to_java(self, params, java_obj):
+        """
+        Transfers the embedded params and additional params to the
+        input Java object.
+        :param params: additional params (overwriting embedded values)
+        :param java_obj: Java object to receive the params
+        """
+        paramMap = self._merge_params(params)
+        for param in self.params:
+            if param in paramMap:
+                java_obj.set(param.name, paramMap[param])
+
+    def _empty_java_param_map(self):
+        """
+        Returns an empty Java ParamMap reference.
+        """
+        return _jvm().org.apache.spark.ml.param.ParamMap()
+
+    def _create_java_param_map(self, params, java_obj):
+        paramMap = self._empty_java_param_map()
+        for param, value in params.items():
+            if param.parent is self:
+                paramMap.put(java_obj.getParam(param.name), value)
+        return paramMap
+
+
+@inherit_doc
+class JavaEstimator(Estimator, JavaWrapper):
+    """
+    Base class for an :py:class:`Estimator` that wraps a Java/Scala
+    implementation.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def _create_model(self, java_model):
+        """
+        Creates a model from the input Java model reference.
+        """
+        return JavaModel(java_model)
+
+    def _fit_java(self, dataset, params={}):
+        """
+        Fits a Java model to the input dataset.
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.DataFrame`
+        :param params: additional params (overwriting embedded values)
+        :return: fitted Java model
+        """
+        java_obj = self._java_obj()
+        self._transfer_params_to_java(params, java_obj)
+        return java_obj.fit(dataset._jdf, self._empty_java_param_map())
+
+    def fit(self, dataset, params={}):
+        java_model = self._fit_java(dataset, params)
+        return self._create_model(java_model)
+
+
+@inherit_doc
+class JavaTransformer(Transformer, JavaWrapper):
+    """
+    Base class for a :py:class:`Transformer` that wraps a Java/Scala
+    implementation.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def transform(self, dataset, params={}):
+        java_obj = self._java_obj()
+        self._transfer_params_to_java({}, java_obj)
+        java_param_map = self._create_java_param_map(params, java_obj)
+        return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
+                         dataset.sql_ctx)
+
+
+@inherit_doc
+class JavaModel(JavaTransformer):
+    """
+    Base class for :py:class:`Model`s that wrap Java/Scala
+    implementations.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self, java_model):
+        super(JavaModel, self).__init__()
+        self._java_model = java_model
+
+    def _java_obj(self):
+        return self._java_model

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 7d7550c..c3a6938 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1794,20 +1794,6 @@ class Row(tuple):
             return "<Row(%s)>" % ", ".join(self)
 
 
-def inherit_doc(cls):
-    for name, func in vars(cls).items():
-        # only inherit docstring for public functions
-        if name.startswith("_"):
-            continue
-        if not func.__doc__:
-            for parent in cls.__bases__:
-                parent_func = getattr(parent, name, None)
-                if parent_func and getattr(parent_func, "__doc__", None):
-                    func.__doc__ = parent_func.__doc__
-                    break
-    return cls
-
-
 class DataFrame(object):
 
     """A collection of rows that have the same columns.

http://git-wip-us.apache.org/repos/asf/spark/blob/e80dc1c5/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index 53c3455..84cb89b 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -82,6 +82,13 @@ function run_mllib_tests() {
     run_test "pyspark/mllib/tests.py"
 }
 
+function run_ml_tests() {
+    echo "Run ml tests ..."
+    run_test "pyspark/ml/feature.py"
+    run_test "pyspark/ml/classification.py"
+    run_test "pyspark/ml/tests.py"
+}
+
 function run_streaming_tests() {
     echo "Run streaming tests ..."
     run_test "pyspark/streaming/util.py"
@@ -103,6 +110,7 @@ $PYSPARK_PYTHON --version
 run_core_tests
 run_sql_tests
 run_mllib_tests
+run_ml_tests
 run_streaming_tests
 
 # Try to test with PyPy

