You are viewing a plain text version of this content; the canonical (hyperlinked) version is available in the original mailing-list archive.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/15 03:14:03 UTC
spark git commit: [SPARK-7648] [MLLIB] Add weights and intercept to
GLM wrappers in spark.ml
Repository: spark
Updated Branches:
refs/heads/master b208f998b -> 723853eda
[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml
Otherwise, users can only use `transform` on the models. brkyvz
Author: Xiangrui Meng <me...@databricks.com>
Closes #6156 from mengxr/SPARK-7647 and squashes the following commits:
1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/723853ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/723853ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/723853ed
Branch: refs/heads/master
Commit: 723853edab18d28515af22097b76e4e6574b228e
Parents: b208f99
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu May 14 18:13:58 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu May 14 18:13:58 2015 -0700
----------------------------------------------------------------------
python/pyspark/ml/classification.py | 18 ++++++++++++++++++
python/pyspark/ml/regression.py | 18 ++++++++++++++++++
python/pyspark/ml/wrapper.py | 8 +++++++-
3 files changed, 43 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/723853ed/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 96d2905..8c9a55e 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
>>> model.transform(test0).head().prediction
0.0
+ >>> model.weights
+ DenseVector([5.5...])
+ >>> model.intercept
+ -2.68...
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().prediction
1.0
@@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel):
Model fitted by LogisticRegression.
"""
+ @property
+ def weights(self):
+ """
+ Model weights.
+ """
+ return self._call_java("weights")
+
+ @property
+ def intercept(self):
+ """
+ Model intercept.
+ """
+ return self._call_java("intercept")
+
class TreeClassifierParams(object):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/723853ed/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 0ab5c6c..2803864 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
-1.0
+ >>> model.weights
+ DenseVector([1.0])
+ >>> model.intercept
+ 0.0
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
@@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel):
Model fitted by LinearRegression.
"""
+ @property
+ def weights(self):
+ """
+ Model weights.
+ """
+ return self._call_java("weights")
+
+ @property
+ def intercept(self):
+ """
+ Model intercept.
+ """
+ return self._call_java("intercept")
+
class TreeRegressorParams(object):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/723853ed/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index f5ac2a3..dda6c6a 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -21,7 +21,7 @@ from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
-from pyspark.mllib.common import inherit_doc
+from pyspark.mllib.common import inherit_doc, _java2py, _py2java
def _jvm():
@@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer):
def _java_obj(self):
return self._java_model
+ def _call_java(self, name, *args):
+ m = getattr(self._java_model, name)
+ sc = SparkContext._active_spark_context
+ java_args = [_py2java(sc, arg) for arg in args]
+ return _java2py(sc, m(*java_args))
+
@inherit_doc
class JavaEvaluator(Evaluator, JavaWrapper):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org