You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2017/09/13 12:12:27 UTC
spark git commit: [SPARK-21690][ML] one-pass imputer
Repository: spark
Updated Branches:
refs/heads/master ca00cc70d -> 0fa5b7cac
[SPARK-21690][ML] one-pass imputer
## What changes were proposed in this pull request?
Parallelize the computation of surrogate statistics across all input columns (one pass instead of one job per column).
performance tests:
|numColumns| Mean(Old) | Median(Old) | Mean(RDD) | Median(RDD) | Mean(DF) | Median(DF) |
|------|----------|------------|----------|------------|----------|------------|
|1|0.0771394713|0.0658712813|0.080779802|0.048165981499999996|0.10525509870000001|0.0499620203|
|10|0.7234340630999999|0.5954440414|0.0867935197|0.13263428659999998|0.09255724889999999|0.1573943635|
|100|7.3756451568|6.2196631259|0.1911931552|0.8625376817000001|0.5557462431|1.7216837982000002|
## How was this patch tested?
existing tests
Author: Zheng RuiFeng <ru...@foxmail.com>
Closes #18902 from zhengruifeng/parallelize_imputer.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0fa5b7ca
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0fa5b7ca
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0fa5b7ca
Branch: refs/heads/master
Commit: 0fa5b7cacca4e867dd9f787cc2801616967932a4
Parents: ca00cc7
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Wed Sep 13 20:12:21 2017 +0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Wed Sep 13 20:12:21 2017 +0800
----------------------------------------------------------------------
.../org/apache/spark/ml/feature/Imputer.scala | 56 ++++++++++++++------
1 file changed, 41 insertions(+), 15 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/0fa5b7ca/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 9e023b9..1f36ece 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -133,23 +133,49 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
override def fit(dataset: Dataset[_]): ImputerModel = {
transformSchema(dataset.schema, logging = true)
val spark = dataset.sparkSession
- import spark.implicits._
- val surrogates = $(inputCols).map { inputCol =>
- val ic = col(inputCol)
- val filtered = dataset.select(ic.cast(DoubleType))
- .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
- if(filtered.take(1).length == 0) {
- throw new SparkException(s"surrogate cannot be computed. " +
- s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})")
- }
- val surrogate = $(strategy) match {
- case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
- case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
- }
- surrogate
+
+ val cols = $(inputCols).map { inputCol =>
+ when(col(inputCol).equalTo($(missingValue)), null)
+ .when(col(inputCol).isNaN, null)
+ .otherwise(col(inputCol))
+ .cast("double")
+ .as(inputCol)
+ }
+
+ val results = $(strategy) match {
+ case Imputer.mean =>
+ // Function avg will ignore null automatically.
+ // For a column only containing null, avg will return null.
+ val row = dataset.select(cols.map(avg): _*).head()
+ Array.range(0, $(inputCols).length).map { i =>
+ if (row.isNullAt(i)) {
+ Double.NaN
+ } else {
+ row.getDouble(i)
+ }
+ }
+
+ case Imputer.median =>
+ // Function approxQuantile will ignore null automatically.
+ // For a column only containing null, approxQuantile will return an empty array.
+ dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001)
+ .map { array =>
+ if (array.isEmpty) {
+ Double.NaN
+ } else {
+ array.head
+ }
+ }
+ }
+
+ val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1)
+ if (emptyCols.nonEmpty) {
+ throw new SparkException(s"surrogate cannot be computed. " +
+ s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " +
+ s"missingValue(${$(missingValue)})")
}
- val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
+ val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results)))
val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
val surrogateDF = spark.createDataFrame(rows, schema)
copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org