You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2017/09/13 12:12:27 UTC

spark git commit: [SPARK-21690][ML] one-pass imputer

Repository: spark
Updated Branches:
  refs/heads/master ca00cc70d -> 0fa5b7cac


[SPARK-21690][ML] one-pass imputer

## What changes were proposed in this pull request?
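Parallelize the computation across all input columns: the surrogates (mean or median) for every input column are now computed in one combined query over the dataset, instead of one separate query per column.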

Performance tests:

|numColumns| Mean(Old) | Median(Old) | Mean(RDD) | Median(RDD) | Mean(DF) | Median(DF) |
|------|----------|------------|----------|------------|----------|------------|
|1|0.0771394713|0.0658712813|0.080779802|0.048165981499999996|0.10525509870000001|0.0499620203|
|10|0.7234340630999999|0.5954440414|0.0867935197|0.13263428659999998|0.09255724889999999|0.1573943635|
|100|7.3756451568|6.2196631259|0.1911931552|0.8625376817000001|0.5557462431|1.7216837982000002|

## How was this patch tested?
Existing tests.

Author: Zheng RuiFeng <ru...@foxmail.com>

Closes #18902 from zhengruifeng/parallelize_imputer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0fa5b7ca
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0fa5b7ca
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0fa5b7ca

Branch: refs/heads/master
Commit: 0fa5b7cacca4e867dd9f787cc2801616967932a4
Parents: ca00cc7
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Wed Sep 13 20:12:21 2017 +0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Wed Sep 13 20:12:21 2017 +0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Imputer.scala   | 56 ++++++++++++++------
 1 file changed, 41 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0fa5b7ca/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 9e023b9..1f36ece 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -133,23 +133,49 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
   override def fit(dataset: Dataset[_]): ImputerModel = {
     transformSchema(dataset.schema, logging = true)
     val spark = dataset.sparkSession
-    import spark.implicits._
-    val surrogates = $(inputCols).map { inputCol =>
-      val ic = col(inputCol)
-      val filtered = dataset.select(ic.cast(DoubleType))
-        .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
-      if(filtered.take(1).length == 0) {
-        throw new SparkException(s"surrogate cannot be computed. " +
-          s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})")
-      }
-      val surrogate = $(strategy) match {
-        case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
-        case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
-      }
-      surrogate
+
+    val cols = $(inputCols).map { inputCol =>
+      when(col(inputCol).equalTo($(missingValue)), null)
+        .when(col(inputCol).isNaN, null)
+        .otherwise(col(inputCol))
+        .cast("double")
+        .as(inputCol)
+    }
+
+    val results = $(strategy) match {
+      case Imputer.mean =>
+        // Function avg will ignore null automatically.
+        // For a column only containing null, avg will return null.
+        val row = dataset.select(cols.map(avg): _*).head()
+        Array.range(0, $(inputCols).length).map { i =>
+          if (row.isNullAt(i)) {
+            Double.NaN
+          } else {
+            row.getDouble(i)
+          }
+        }
+
+      case Imputer.median =>
+        // Function approxQuantile will ignore null automatically.
+        // For a column only containing null, approxQuantile will return an empty array.
+        dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001)
+          .map { array =>
+            if (array.isEmpty) {
+              Double.NaN
+            } else {
+              array.head
+            }
+          }
+    }
+
+    val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1)
+    if (emptyCols.nonEmpty) {
+      throw new SparkException(s"surrogate cannot be computed. " +
+        s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " +
+        s"missingValue(${$(missingValue)})")
     }
 
-    val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
+    val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results)))
     val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
     val surrogateDF = spark.createDataFrame(rows, schema)
     copyValues(new ImputerModel(uid, surrogateDF).setParent(this))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org