Posted to commits@spark.apache.org by ml...@apache.org on 2017/03/16 10:49:35 UTC

spark git commit: [SPARK-13568][ML] Create feature transformer to impute missing values

Repository: spark
Updated Branches:
  refs/heads/master 1472cac4b -> d647aae27


[SPARK-13568][ML] Create feature transformer to impute missing values

## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-13568
It is quite common to encounter missing values in data sets. It would be useful to implement a Transformer that can impute missing data points, similar to, for example, the Imputer in scikit-learn.
Initial options for imputation are mean and median; other approaches such as most frequent could be added later, reusing existing DataFrame functionality where possible (e.g. approximate quantiles).

Currently this PR supports imputation for numeric columns of DoubleType and FloatType; both null and NaN values are treated as missing.
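
For illustration, a minimal usage sketch of the new API (the input DataFrame df and its column names are hypothetical):

    import org.apache.spark.ml.feature.Imputer

    val imputer = new Imputer()
      .setInputCols(Array("age", "income"))    // hypothetical input columns
      .setOutputCols(Array("age_imp", "income_imp"))
      .setStrategy("median")                   // default: "mean"
      .setMissingValue(-1.0)                   // default placeholder: Double.NaN
    val model = imputer.fit(df)                // learns one surrogate per input column
    val imputed = model.transform(df)          // replaces nulls and -1.0 with the surrogate
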
## How was this patch tested?

New unit tests and manual testing.

Author: Yuhao Yang <hh...@gmail.com>
Author: Yuhao Yang <yu...@intel.com>
Author: Yuhao <yu...@intel.com>

Closes #11601 from hhbyyh/imputer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d647aae2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d647aae2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d647aae2

Branch: refs/heads/master
Commit: d647aae278ef31a07fc64715eb07e48294d94bb8
Parents: 1472cac
Author: Yuhao Yang <hh...@gmail.com>
Authored: Thu Mar 16 12:49:59 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Thu Mar 16 12:49:59 2017 +0200

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Imputer.scala   | 259 +++++++++++++++++++
 .../apache/spark/ml/feature/ImputerSuite.scala  | 185 +++++++++++++
 2 files changed, 444 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d647aae2/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
new file mode 100644
index 0000000..b1a802e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.HasInputCols
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * Params for [[Imputer]] and [[ImputerModel]].
+ */
+private[feature] trait ImputerParams extends Params with HasInputCols {
+
+  /**
+   * The imputation strategy.
+   * If "mean", then replace missing values using the mean value of the feature.
+   * If "median", then replace missing values using the approximate median value of the feature.
+   * Default: mean
+   *
+   * @group param
+   */
+  final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " +
+    s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
+    s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
+    ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))
+
+  /** @group getParam */
+  def getStrategy: String = $(strategy)
+
+  /**
+   * The placeholder for the missing values. All occurrences of missingValue will be imputed.
+   * Note that null values are always treated as missing.
+   * Default: Double.NaN
+   *
+   * @group param
+   */
+  final val missingValue: DoubleParam = new DoubleParam(this, "missingValue",
+    "The placeholder for the missing values. All occurrences of missingValue will be imputed")
+
+  /** @group getParam */
+  def getMissingValue: Double = $(missingValue)
+
+  /**
+   * Param for output column names.
+   * @group param
+   */
+  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
+    "output column names")
+
+  /** @group getParam */
+  final def getOutputCols: Array[String] = $(outputCols)
+
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
+      s" duplicates: (${$(inputCols).mkString(", ")})")
+    require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" +
+      s" duplicates: (${$(outputCols).mkString(", ")})")
+    require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" +
+      s" and outputCols(${$(outputCols).length}) should have the same length")
+    val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
+      val inputField = schema(inputCol)
+      SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
+      StructField(outputCol, inputField.dataType, inputField.nullable)
+    }
+    StructType(schema ++ outputFields)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Imputation estimator for completing missing values, either using the mean or the median
+ * of the column in which the missing values are located. The input column should be of
+ * DoubleType or FloatType. Imputer does not currently support categorical features
+ * (SPARK-15041) and may produce incorrect values when applied to a categorical feature.
+ *
+ * Note that the mean/median value is computed after filtering out missing values.
+ * All Null values in the input column are treated as missing, and so are also imputed. For
+ * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
+ */
+@Experimental
+class Imputer @Since("2.2.0")(override val uid: String)
+  extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {
+
+  @Since("2.2.0")
+  def this() = this(Identifiable.randomUID("imputer"))
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  /**
+   * Imputation strategy. Available options are ["mean", "median"].
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setStrategy(value: String): this.type = set(strategy, value)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setMissingValue(value: Double): this.type = set(missingValue, value)
+
+  setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)
+
+  override def fit(dataset: Dataset[_]): ImputerModel = {
+    transformSchema(dataset.schema, logging = true)
+    val spark = dataset.sparkSession
+    import spark.implicits._
+    val surrogates = $(inputCols).map { inputCol =>
+      val ic = col(inputCol)
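+      // Exclude values that are null, NaN or equal to missingValue before computing the surrogate.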
+      val filtered = dataset.select(ic.cast(DoubleType))
+        .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
+      if (filtered.take(1).isEmpty) {
+        throw new SparkException("surrogate cannot be computed. " +
+          s"All the values in $inputCol are Null, NaN or missingValue(${$(missingValue)})")
+      }
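+      // mean: exact average; median: approxQuantile at 0.5 with relative error 0.001.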
+      val surrogate = $(strategy) match {
+        case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
+        case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
+      }
+      surrogate
+    }
+
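+    // Pack the surrogates into a single-row DataFrame, one DoubleType column per input column.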
+    val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
+    val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
+    val surrogateDF = spark.createDataFrame(rows, schema)
+    copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): Imputer = defaultCopy(extra)
+}
+
+@Since("2.2.0")
+object Imputer extends DefaultParamsReadable[Imputer] {
+
+  /** strategy names that Imputer currently supports. */
+  private[ml] val mean = "mean"
+  private[ml] val median = "median"
+
+  @Since("2.2.0")
+  override def load(path: String): Imputer = super.load(path)
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by [[Imputer]].
+ *
+ * @param surrogateDF a DataFrame containing the inputCols and their corresponding surrogates,
+ *                    which are used to replace the missing values in the input DataFrame.
+ */
+@Experimental
+class ImputerModel private[ml](
+    override val uid: String,
+    val surrogateDF: DataFrame)
+  extends Model[ImputerModel] with ImputerParams with MLWritable {
+
+  import ImputerModel._
+
+  /** @group setParam */
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+    var outputDF = dataset
+    val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
+
+    $(inputCols).zip($(outputCols)).zip(surrogates).foreach {
+      case ((inputCol, outputCol), surrogate) =>
+        val inputType = dataset.schema(inputCol).dataType
+        val ic = col(inputCol)
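+        // Replace nulls and values equal to missingValue with the surrogate, then cast back
+        // to the column's original type. (In Spark SQL, NaN === NaN evaluates to true, so the
+        // default NaN placeholder is matched as well.)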
+        outputDF = outputDF.withColumn(outputCol,
+          when(ic.isNull, surrogate)
+          .when(ic === $(missingValue), surrogate)
+          .otherwise(ic)
+          .cast(inputType))
+    }
+    outputDF.toDF()
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): ImputerModel = {
+    val copied = new ImputerModel(uid, surrogateDF)
+    copyValues(copied, extra).setParent(parent)
+  }
+
+  @Since("2.2.0")
+  override def write: MLWriter = new ImputerModelWriter(this)
+}
+
+
+@Since("2.2.0")
+object ImputerModel extends MLReadable[ImputerModel] {
+
+  private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val dataPath = new Path(path, "data").toString
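+      // The surrogate DataFrame holds a single row, so repartition(1) writes one parquet file.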
+      instance.surrogateDF.repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class ImputerReader extends MLReader[ImputerModel] {
+
+    private val className = classOf[ImputerModel].getName
+
+    override def load(path: String): ImputerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val surrogateDF = sqlContext.read.parquet(dataPath)
+      val model = new ImputerModel(metadata.uid, surrogateDF)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("2.2.0")
+  override def read: MLReader[ImputerModel] = new ImputerReader
+
+  @Since("2.2.0")
+  override def load(path: String): ImputerModel = super.load(path)
+}
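
For reference, a fitted model exposes the learned statistics through surrogateDF; a minimal sketch of inspecting it, assuming the usage example above:

    model.surrogateDF.show()
    // a single row, with one DoubleType column per input column holding its mean or median

This single-row DataFrame is exactly what ImputerModel.write persists as parquet.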

http://git-wip-us.apache.org/repos/asf/spark/blob/d647aae2/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
new file mode 100644
index 0000000..ee2ba73
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row}
+
+class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  test("Imputer for Double with default missing Value NaN") {
+    val df = spark.createDataFrame( Seq(
+      (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0),
+      (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0),
+      (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0),
+      (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0)
+    )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1",
+      "expected_mean_value2", "expected_median_value2")
+    val imputer = new Imputer()
+      .setInputCols(Array("value1", "value2"))
+      .setOutputCols(Array("out1", "out2"))
+    ImputerSuite.iterateStrategyTest(imputer, df)
+  }
+
+  test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") {
+    val df = spark.createDataFrame( Seq(
+      (0, 1.0, 1.0, 1.0),
+      (1, 3.0, 3.0, 3.0),
+      (2, Double.NaN, Double.NaN, Double.NaN),
+      (3, -1.0, 2.0, 3.0)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
+      .setMissingValue(-1.0)
+    ImputerSuite.iterateStrategyTest(imputer, df)
+  }
+
+  test("Imputer for Float with missing Value -1.0") {
+    val df = spark.createDataFrame( Seq(
+      (0, 1.0F, 1.0F, 1.0F),
+      (1, 3.0F, 3.0F, 3.0F),
+      (2, 10.0F, 10.0F, 10.0F),
+      (3, 10.0F, 10.0F, 10.0F),
+      (4, -1.0F, 6.0F, 3.0F)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
+      .setMissingValue(-1)
+    ImputerSuite.iterateStrategyTest(imputer, df)
+  }
+
+  test("Imputer should impute null as well as 'missingValue'") {
+    val rawDf = spark.createDataFrame( Seq(
+      (0, 4.0, 4.0, 4.0),
+      (1, 10.0, 10.0, 10.0),
+      (2, 10.0, 10.0, 10.0),
+      (3, Double.NaN, 8.0, 10.0),
+      (4, -1.0, 8.0, 10.0)
+    )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value")
+    val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value")
+    val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
+    ImputerSuite.iterateStrategyTest(imputer, df)
+  }
+
+  test("Imputer throws exception when surrogate cannot be computed") {
+    val df = spark.createDataFrame( Seq(
+      (0, Double.NaN, 1.0, 1.0),
+      (1, Double.NaN, 3.0, 3.0),
+      (2, Double.NaN, Double.NaN, Double.NaN)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    Seq("mean", "median").foreach { strategy =>
+      val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
+        .setStrategy(strategy)
+      withClue("Imputer should fail all the values are invalid") {
+        val e: SparkException = intercept[SparkException] {
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("surrogate cannot be computed"))
+      }
+    }
+  }
+
+  test("Imputer input & output column validation") {
+    val df = spark.createDataFrame( Seq(
+      (0, 1.0, 1.0, 1.0),
+      (1, Double.NaN, 3.0, 3.0),
+      (2, Double.NaN, Double.NaN, Double.NaN)
+    )).toDF("id", "value1", "value2", "value3")
+    Seq("mean", "median").foreach { strategy =>
+      withClue("Imputer should fail if inputCols and outputCols are different length") {
+        val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+          val imputer = new Imputer().setStrategy(strategy)
+            .setInputCols(Array("value1", "value2"))
+            .setOutputCols(Array("out1"))
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("should have the same length"))
+      }
+
+      withClue("Imputer should fail if inputCols contains duplicates") {
+        val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+          val imputer = new Imputer().setStrategy(strategy)
+            .setInputCols(Array("value1", "value1"))
+            .setOutputCols(Array("out1", "out2"))
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("inputCols contains duplicates"))
+      }
+
+      withClue("Imputer should fail if outputCols contains duplicates") {
+        val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+          val imputer = new Imputer().setStrategy(strategy)
+            .setInputCols(Array("value1", "value2"))
+            .setOutputCols(Array("out1", "out1"))
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("outputCols contains duplicates"))
+      }
+    }
+  }
+
+  test("Imputer read/write") {
+    val t = new Imputer()
+      .setInputCols(Array("myInputCol"))
+      .setOutputCols(Array("myOutputCol"))
+      .setMissingValue(-1.0)
+    testDefaultReadWrite(t)
+  }
+
+  test("ImputerModel read/write") {
+    val spark = this.spark
+    import spark.implicits._
+    val surrogateDF = Seq(1.234).toDF("myInputCol")
+
+    val instance = new ImputerModel(
+      "myImputer", surrogateDF)
+      .setInputCols(Array("myInputCol"))
+      .setOutputCols(Array("myOutputCol"))
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns)
+    assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
+  }
+
+}
+
+object ImputerSuite {
+
+  /**
+   * Fits the given Imputer with each supported strategy ("mean" and "median") and checks
+   * the imputed output against the corresponding expected column.
+   * @param df DataFrame that, for each input column, contains an
+   *           "expected_${strategy}_${inputCol}" column holding the expected imputed values
+   */
+  def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = {
+
+    Seq("mean", "median").foreach { strategy =>
+      imputer.setStrategy(strategy)
+      val model = imputer.fit(df)
+      val resultDF = model.transform(df)
+      imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) =>
+        resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach {
+          case Row(exp: Float, out: Float) =>
+            assert((exp.isNaN && out.isNaN) || (exp == out),
+              s"Imputed values differ. Expected: $exp, actual: $out")
+          case Row(exp: Double, out: Double) =>
+            assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
+              s"Imputed values differ. Expected: $exp, actual: $out")
+        }
+      }
+    }
+  }
+}

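
As a usage note, the suite can be run on its own with Spark's sbt launcher, e.g. (assuming a standard Spark checkout; the exact invocation may vary with the build setup):

    build/sbt "mllib/testOnly org.apache.spark.ml.feature.ImputerSuite"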
