Posted to commits@spark.apache.org by me...@apache.org on 2014/07/21 03:40:41 UTC
git commit: [SPARK-2552][MLLIB] stabilize logistic function in pyspark
Repository: spark
Updated Branches:
refs/heads/master 9564f8548 -> b86db517b
[SPARK-2552][MLLIB] stabilize logistic function in pyspark
to avoid overflow in `exp(x)` if `x` is large.
Author: Xiangrui Meng <me...@databricks.com>
Closes #1493 from mengxr/py-logistic and squashes the following commits:
259e863 [Xiangrui Meng] stabilize logistic function in pyspark
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b86db517
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b86db517
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b86db517
Branch: refs/heads/master
Commit: b86db517b6a2795f687211205b6a14c8685873eb
Parents: 9564f85
Author: Xiangrui Meng <me...@databricks.com>
Authored: Sun Jul 20 18:40:36 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun Jul 20 18:40:36 2014 -0700
----------------------------------------------------------------------
python/pyspark/mllib/classification.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b86db517/python/pyspark/mllib/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 1c0c536..9e28dfb 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -63,7 +63,10 @@ class LogisticRegressionModel(LinearModel):
     def predict(self, x):
         _linear_predictor_typecheck(x, self._coeff)
         margin = _dot(x, self._coeff) + self._intercept
-        prob = 1/(1 + exp(-margin))
+        if margin > 0:
+            prob = 1 / (1 + exp(-margin))
+        else:
+            prob = 1 - 1 / (1 + exp(margin))
         return 1 if prob > 0.5 else 0
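
For illustration, the effect of the change can be reproduced outside of Spark. The sketch below is standalone and is not part of the commit: the function names `naive_logistic`, `stable_logistic`, and `predict_label` are hypothetical, and it assumes `exp` behaves like Python's `math.exp`, which raises `OverflowError` for large arguments. The branching mirrors the patch, so `exp` is only ever called with a non-positive argument.

```python
import math


def naive_logistic(margin):
    # Pre-patch form: exp(-margin) overflows when margin is a large
    # negative number (for example, margin = -1000 gives exp(1000)).
    return 1 / (1 + math.exp(-margin))


def stable_logistic(margin):
    # Same branching as the patch: the argument passed to exp() is
    # never positive, so the call cannot overflow.
    if margin > 0:
        return 1 / (1 + math.exp(-margin))
    else:
        return 1 - 1 / (1 + math.exp(margin))


def predict_label(margin):
    # Hypothetical stand-in for the thresholding done in predict().
    return 1 if stable_logistic(margin) > 0.5 else 0


if __name__ == "__main__":
    try:
        naive_logistic(-1000.0)
    except OverflowError as err:
        print("naive form failed:", err)
    print("stable form:", stable_logistic(-1000.0))
    print("label:", predict_label(-1000.0))
```

For moderate margins the two forms agree to within floating-point rounding; they differ only where the naive form would fail.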