Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/03/18 07:30:46 UTC

[GitHub] [spark] zhengruifeng commented on a change in pull request #27944: [SPARK-31180][ML] Implement PowerTransform

URL: https://github.com/apache/spark/pull/27944#discussion_r394149351
 
 

 ##########
 File path: mllib/src/main/scala/org/apache/spark/ml/feature/PowerTransform.scala
 ##########
 @@ -0,0 +1,561 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.commons.math3.analysis._
+import org.apache.commons.math3.optim._
+import org.apache.commons.math3.optim.nonlinear.scalar._
+import org.apache.commons.math3.optim.univariate._
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml._
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+
+/**
+ * Params for [[PowerTransform]] and [[PowerTransformModel]].
+ */
+private[feature] trait PowerTransformParams extends Params with HasInputCol with HasOutputCol {
+
+  /**
+   * The model type which is a string (case-sensitive).
+   * Supported options: "yeo-johnson", "box-cox".
+   * (default = yeo-johnson)
+   *
+   * @group param
+   */
+  final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
+    "which is a string (case-sensitive). Supported options: yeo-johnson (default), and box-cox.",
+    ParamValidators.inArray[String](PowerTransform.supportedModelTypes))
+
+  /** @group getParam */
+  final def getModelType: String = $(modelType)
+
+  setDefault(modelType -> PowerTransform.YeoJohnson)
+
+  /**
+   * Param for the number of bins used to down-sample the distinct (value, count)
+   * pairs of each column in statistics computation.
+   * If 0, no down-sampling will occur.
+   * Default: 100,000.
+   * @group expertParam
+   */
+  val numBins: IntParam = new IntParam(this, "numBins", "Number of bins to down-sample " +
+    "the distinct (value, count) pairs of each column in statistics computation. " +
+    "If 0, no down-sampling will occur. Must be >= 0.",
+    ParamValidators.gtEq(0))
+
+  /** @group expertGetParam */
+  def getNumBins: Int = $(numBins)
+
+  setDefault(numBins -> 100000)
+
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+  }
+}
+
+
+/**
+ * Apply a power transform to make data more Gaussian-like.
+ * Currently, PowerTransform supports the Box-Cox transform and the Yeo-Johnson transform.
+ * The optimal parameter for stabilizing variance and minimizing skewness is estimated through
+ * maximum likelihood.
+ * Box-Cox requires input data to be strictly positive, while Yeo-Johnson supports both
+ * positive and negative data.
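+ *
+ * For reference, the standard definitions of the two transforms are:
+ * Box-Cox: y = (x^lambda - 1) / lambda for lambda != 0, and y = ln(x) for lambda = 0.
+ * Yeo-Johnson: for x >= 0, y = ((x + 1)^lambda - 1) / lambda (lambda != 0),
+ * or y = ln(x + 1) (lambda = 0); for x < 0,
+ * y = -((1 - x)^(2 - lambda) - 1) / (2 - lambda) (lambda != 2), or y = -ln(1 - x) (lambda = 2).
+ *
+ * A usage sketch, against the API defined in this file (`df` is assumed to be a
+ * DataFrame with a Vector column named "features"):
+ * {{{
+ *   val pt = new PowerTransform()
+ *     .setInputCol("features")
+ *     .setOutputCol("transformed")
+ *     .setModelType("box-cox")
+ *   val model = pt.fit(df)
+ *   val transformed = model.transform(df)
+ * }}}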
+ */
+@Since("3.1.0")
+class PowerTransform @Since("3.1.0")(@Since("3.1.0") override val uid: String)
+  extends Estimator[PowerTransformModel] with PowerTransformParams with DefaultParamsWritable {
+
+  import PowerTransform._
+
+  def this() = this(Identifiable.randomUID("power_trans"))
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setModelType(value: String): this.type = set(modelType, value)
+
+  /** @group expertSetParam */
+  def setNumBins(value: Int): this.type = set(numBins, value)
+
+  override def fit(dataset: Dataset[_]): PowerTransformModel = {
+    transformSchema(dataset.schema, logging = true)
+
+    val spark = dataset.sparkSession
+    import spark.implicits._
+
+    val localModelType = $(modelType)
+    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+    val numRows = dataset.count()
+
+    val validateFunc = $(modelType) match {
+      case BoxCox => vec: Vector => requirePositiveValues(vec)
+      case YeoJohnson => vec: Vector => requireNonNaNValues(vec)
+    }
+
+    var pairCounts = dataset
+      .select($(inputCol))
+      .flatMap { case Row(vec: Vector) =>
+        require(vec.size == numFeatures)
+        validateFunc(vec)
+        vec.nonZeroIterator
+      }.toDF("col", "value")
+      .groupBy("col", "value")
+      .agg(count(lit(0)).as("cnt"))
+      .sort("col", "value")
+
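+    // Down-sample wide columns: a column is binned only when it has at least
+    // 2 * numBins distinct (value, cnt) pairs; makeBins (not shown in this hunk)
+    // then merges consecutive pairs so the likelihood solver evaluates a bounded
+    // number of points per column.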
+    val groups = if (0 < $(numBins) && $(numBins) <= numRows) {
+      val localNumBins = $(numBins)
+      pairCounts
+        .groupBy("col")
+        .count()
+        .as[(Int, Long)]
+        .flatMap { case (col, num) =>
+          val group = num / localNumBins
+          if (group >= 2) {
+            Some((col, group))
+          } else {
+            None
+          }
+        }.collect().toMap
+    } else Map.empty[Int, Long]
+
+    if (groups.nonEmpty) {
+      pairCounts = makeBins(pairCounts.as[(Int, Double, Long)], groups)
+        .toDF("col", "value", "cnt")
+    }
+
+    val solutions = pairCounts
+      .groupBy("col")
+      .agg(collect_list(struct("value", "cnt")))
+      .as[(Int, Seq[(Double, Long)])]
+      .map { case (col, seq) =>
+        val nnz = seq.iterator.map(_._2).sum
+        val nz = numRows - nnz
+        val (solution, _) = localModelType match {
+          case BoxCox =>
+            // Box-Cox input is strictly positive, so nonZeroIterator saw every value.
+            require(nz == 0)
+            val computeIter = () => seq.iterator
+            solveBoxCox(computeIter)
+          case YeoJohnson =>
+            // Zeros were dropped by nonZeroIterator; re-insert them with their count.
+            require(nz >= 0)
+            val computeIter = if (nz > 0) {
+              () => seq.iterator ++ Iterator.single((0.0, nz))
+            } else {
+              () => seq.iterator
+            }
+            solveYeoJohnson(computeIter)
+        }
+        (col, solution)
+      }.collect().toMap
+
+    val lambda = Array.ofDim[Double](numFeatures)
+    solutions.foreach { case (col, solution) => lambda(col) = solution }
+
+    if (solutions.size < numFeatures) {
+      localModelType match {
+        case YeoJohnson =>
+          // if some column only contains 0 values in YeoJohnson
+          val computeIter = () => Iterator.single((0.0, numRows))
+          val (zeroSolution, _) = solveYeoJohnson(computeIter)
+          Iterator.range(0, numFeatures)
+            .filterNot(solutions.contains)
+            .foreach { col => lambda(col) = zeroSolution }
+
+        case BoxCox =>
+          // This should never happen.
+          throw new IllegalArgumentException("BoxCox requires positive values")
+      }
+    }
+
+    copyValues(new PowerTransformModel(uid, Vectors.dense(lambda).compressed)
+      .setParent(this))
+  }
+
+  override def copy(extra: ParamMap): PowerTransform = defaultCopy(extra)
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+}
+
+
+@Since("3.1.0")
+object PowerTransform extends DefaultParamsReadable[PowerTransform] {
+
+  override def load(path: String): PowerTransform = super.load(path)
+
+  /** String name for Box-Cox transform model type. */
+  private[feature] val BoxCox: String = "box-cox"
+
+  /** String name for Yeo-Johnson transform model type. */
+  private[feature] val YeoJohnson: String = "yeo-johnson"
+
+  /* Set of modelTypes that PowerTransform supports */
+  private[feature] val supportedModelTypes = Array(BoxCox, YeoJohnson)
+
+  private[feature] def brentSolve(obj: UnivariateFunction): (Double, Double) = {
 
 Review comment:
   Compared with scikit-learn's [implementation](https://github.com/scikit-learn/scikit-learn/blob/b189bf60708af22dde82a00aca7b5a54290b666d/sklearn/preprocessing/_data.py#L3042):
   - use the same `tol` = 1.48E-8;
   - sklearn brackets the search with [-2, 2], though scipy's docs caution that "Providing the pair (xa, xb) does not always mean the obtained solution will satisfy xa <= x <= xb";
   - sklearn uses `maxiter` = 500.
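
   For concreteness, a minimal sketch of how `brentSolve` could be wired up with
   commons-math3's `BrentOptimizer` under the settings discussed above (an
   illustration only, not the PR's actual body, which is truncated in this hunk;
   it assumes the objective is the profile log-likelihood to be maximized):

```scala
import org.apache.commons.math3.analysis.UnivariateFunction
import org.apache.commons.math3.optim.MaxEval
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
import org.apache.commons.math3.optim.univariate.{BrentOptimizer, SearchInterval, UnivariateObjectiveFunction}

def brentSolve(obj: UnivariateFunction): (Double, Double) = {
  // Same tolerance as scipy.optimize.brent's default (tol = 1.48e-8), per the comment above.
  val optimizer = new BrentOptimizer(1.48e-8, 1.48e-8)
  val result = optimizer.optimize(
    new MaxEval(500),                       // cap on function evaluations (sklearn/scipy use maxiter = 500)
    new UnivariateObjectiveFunction(obj),
    GoalType.MAXIMIZE,                      // maximize the log-likelihood in lambda
    new SearchInterval(-2.0, 2.0))          // sklearn's bracketing interval
  (result.getPoint, result.getValue)
}
```

   Note that commons-math3's `SearchInterval` is a hard bound, so the solution is
   constrained to [-2, 2], whereas scipy's `brack=(-2, 2)` is only an initial
   bracket and the result may fall outside it.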
