You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/04/17 17:11:13 UTC
spark git commit: [SPARK-21741][ML][PYSPARK] Python API for
DataFrame-based multivariate summarizer
Repository: spark
Updated Branches:
refs/heads/master f39e82ce1 -> 1ca3c50fe
[SPARK-21741][ML][PYSPARK] Python API for DataFrame-based multivariate summarizer
## What changes were proposed in this pull request?
Python API for DataFrame-based multivariate summarizer.
## How was this patch tested?
doctest added.
Author: WeichenXu <we...@databricks.com>
Closes #20695 from WeichenXu123/py_summarizer.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1ca3c50f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1ca3c50f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1ca3c50f
Branch: refs/heads/master
Commit: 1ca3c50fefb34532c78427fa74872db3ecbf7ba2
Parents: f39e82c
Author: WeichenXu <we...@databricks.com>
Authored: Tue Apr 17 10:11:08 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Apr 17 10:11:08 2018 -0700
----------------------------------------------------------------------
python/pyspark/ml/stat.py | 193 ++++++++++++++++++++++++++++++++++++++++-
1 file changed, 192 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1ca3c50f/python/pyspark/ml/stat.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index 93d0f4f..a06ab31 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -19,7 +19,9 @@ import sys
from pyspark import since, SparkContext
from pyspark.ml.common import _java2py, _py2java
-from pyspark.ml.wrapper import _jvm
+from pyspark.ml.wrapper import JavaWrapper, _jvm
+from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.functions import lit
class ChiSquareTest(object):
@@ -195,6 +197,195 @@ class KolmogorovSmirnovTest(object):
_jvm().PythonUtils.toSeq(params)))
class Summarizer(object):
    """
    .. note:: Experimental

    Tools for vectorized statistics on MLlib Vectors.
    The methods in this package provide various statistics for Vectors contained inside DataFrames.
    This class lets users pick the statistics they would like to extract for a given column.

    >>> from pyspark.ml.stat import Summarizer
    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> summarizer = Summarizer.metrics("mean", "count")
    >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    ...                      Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
    >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
    +-----------------------------------+
    |aggregate_metrics(features, weight)|
    +-----------------------------------+
    |[[1.0,1.0,1.0], 1]                 |
    +-----------------------------------+
    <BLANKLINE>
    >>> df.select(summarizer.summary(df.features)).show(truncate=False)
    +--------------------------------+
    |aggregate_metrics(features, 1.0)|
    +--------------------------------+
    |[[1.0,1.5,2.0], 2]              |
    +--------------------------------+
    <BLANKLINE>
    >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
    +--------------+
    |mean(features)|
    +--------------+
    |[1.0,1.0,1.0] |
    +--------------+
    <BLANKLINE>
    >>> df.select(Summarizer.mean(df.features)).show(truncate=False)
    +--------------+
    |mean(features)|
    +--------------+
    |[1.0,1.5,2.0] |
    +--------------+
    <BLANKLINE>

    .. versionadded:: 2.4.0

    """
    @staticmethod
    @since("2.4.0")
    def mean(col, weightCol=None):
        """
        return a column of mean summary
        """
        return Summarizer._get_single_metric(col, weightCol, "mean")

    @staticmethod
    @since("2.4.0")
    def variance(col, weightCol=None):
        """
        return a column of variance summary
        """
        return Summarizer._get_single_metric(col, weightCol, "variance")

    @staticmethod
    @since("2.4.0")
    def count(col, weightCol=None):
        """
        return a column of count summary
        """
        return Summarizer._get_single_metric(col, weightCol, "count")

    @staticmethod
    @since("2.4.0")
    def numNonZeros(col, weightCol=None):
        """
        return a column of numNonZero summary
        """
        return Summarizer._get_single_metric(col, weightCol, "numNonZeros")

    @staticmethod
    @since("2.4.0")
    def max(col, weightCol=None):
        """
        return a column of max summary
        """
        return Summarizer._get_single_metric(col, weightCol, "max")

    @staticmethod
    @since("2.4.0")
    def min(col, weightCol=None):
        """
        return a column of min summary
        """
        return Summarizer._get_single_metric(col, weightCol, "min")

    @staticmethod
    @since("2.4.0")
    def normL1(col, weightCol=None):
        """
        return a column of normL1 summary
        """
        return Summarizer._get_single_metric(col, weightCol, "normL1")

    @staticmethod
    @since("2.4.0")
    def normL2(col, weightCol=None):
        """
        return a column of normL2 summary
        """
        return Summarizer._get_single_metric(col, weightCol, "normL2")

    @staticmethod
    def _check_param(featuresCol, weightCol):
        """
        Validate the feature/weight column pair, defaulting a missing weight to a
        literal 1.0 column, and return the (featuresCol, weightCol) pair.

        :raise TypeError: if either argument is not a :py:class:`Column`.
        """
        if weightCol is None:
            weightCol = lit(1.0)
        if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column):
            # Fixed message: the parameter is named `featuresCol`, not `featureCol`.
            raise TypeError("featuresCol and weightCol should be a Column")
        return featuresCol, weightCol

    @staticmethod
    def _get_single_metric(col, weightCol, metric):
        """
        Build a Column wrapping the JVM-side single-metric aggregate
        ``org.apache.spark.ml.stat.Summarizer.<metric>(col, weightCol)``.
        """
        col, weightCol = Summarizer._check_param(col, weightCol)
        return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric,
                                                col._jc, weightCol._jc))

    @staticmethod
    @since("2.4.0")
    def metrics(*metrics):
        """
        Given a list of metrics, provides a builder that in turn computes metrics from a column.

        See the documentation of [[Summarizer]] for an example.

        The following metrics are accepted (case sensitive):
         - mean: a vector that contains the coefficient-wise mean.
         - variance: a vector that contains the coefficient-wise variance.
         - count: the count of all vectors seen.
         - numNonZeros: a vector with the number of non-zeros for each coefficients
         - max: the maximum for each coefficient.
         - min: the minimum for each coefficient.
         - normL2: the Euclidean norm for each coefficient.
         - normL1: the L1 norm of each coefficient (sum of the absolute values).

        :param metrics:
         metrics that can be provided.
        :return:
         an object of :py:class:`pyspark.ml.stat.SummaryBuilder`

        Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
        interface.
        """
        sc = SparkContext._active_spark_context
        js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
                                       _to_seq(sc, metrics))
        return SummaryBuilder(js)
+
+
class SummaryBuilder(JavaWrapper):
    """
    .. note:: Experimental

    A builder object that provides summary statistics about a given column.

    Users should not directly create such builders, but instead use one of the methods in
    :py:class:`pyspark.ml.stat.Summarizer`

    .. versionadded:: 2.4.0

    """
    def __init__(self, jSummaryBuilder):
        # Wrap the JVM-side SummaryBuilder produced by Summarizer.metrics(...).
        super(SummaryBuilder, self).__init__(jSummaryBuilder)

    @since("2.4.0")
    def summary(self, featuresCol, weightCol=None):
        """
        Returns an aggregate object that contains the summary of the column with the requested
        metrics.

        :param featuresCol:
         a column that contains features Vector object.
        :param weightCol:
         a column that contains weight value. Default weight is 1.0.
        :return:
         an aggregate column that contains the statistics. The exact content of this
         structure is determined during the creation of the builder.
        """
        # Validate/normalize the pair (supplies a literal 1.0 weight when absent).
        checked_features, checked_weight = Summarizer._check_param(featuresCol, weightCol)
        # Delegate to the JVM builder, then wrap the resulting Java column.
        java_summary_col = self._java_obj.summary(checked_features._jc, checked_weight._jc)
        return Column(java_summary_col)
+
+
if __name__ == "__main__":
import doctest
import pyspark.ml.stat
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org