You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/01/04 22:32:27 UTC

spark git commit: [SPARK-9622][ML] DecisionTreeRegressor: provide variance of prediction

Repository: spark
Updated Branches:
  refs/heads/master ba5f81859 -> 93ef9b6a2


[SPARK-9622][ML] DecisionTreeRegressor: provide variance of prediction

DecisionTreeRegressor can now output the (biased sample) variance of each prediction as a Double column, named by the new `varianceCol` param.

Author: Yanbo Liang <yb...@gmail.com>

Closes #8866 from yanboliang/spark-9622.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/93ef9b6a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/93ef9b6a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/93ef9b6a

Branch: refs/heads/master
Commit: 93ef9b6a2aa1830170cb101f191022f2dda62c41
Parents: ba5f818
Author: Yanbo Liang <yb...@gmail.com>
Authored: Mon Jan 4 13:32:14 2016 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Jan 4 13:32:14 2016 -0800

----------------------------------------------------------------------
 .../ml/param/shared/SharedParamsCodeGen.scala   |  1 +
 .../spark/ml/param/shared/sharedParams.scala    | 15 ++++++++
 .../ml/regression/DecisionTreeRegressor.scala   | 36 ++++++++++++++++++--
 .../org/apache/spark/ml/tree/treeParams.scala   | 18 ++++++++++
 .../regression/DecisionTreeRegressorSuite.scala | 26 +++++++++++++-
 5 files changed, 92 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index c7bca12..4aff749 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -44,6 +44,7 @@ private[shared] object SharedParamsCodeGen {
         " probabilities. Note: Not all models output well-calibrated probability estimates!" +
         " These probabilities should be treated as confidences, not precise probabilities",
         Some("\"probability\"")),
+      ParamDesc[String]("varianceCol", "Column name for the biased sample variance of prediction"),
       ParamDesc[Double]("threshold",
         "threshold in binary classification prediction, in range [0, 1]", Some("0.5"),
         isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),

http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index cb2a060..c088c16 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -139,6 +139,21 @@ private[ml] trait HasProbabilityCol extends Params {
 }
 
 /**
+ * Trait for shared param varianceCol.
+ */
+private[ml] trait HasVarianceCol extends Params {
+
+  /**
+   * Param for Column name for the biased sample variance of prediction.
+   * @group param
+   */
+  final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the biased sample variance of prediction")
+
+  /** @group getParam */
+  final def getVarianceCol: String = $(varianceCol)
+}
+
+/**
  * Trait for shared param threshold (default: 0.5).
  */
 private[ml] trait HasThreshold extends Params {

http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 477030d..18c94f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.regression
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
+import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
@@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
 
 /**
  * :: Experimental ::
@@ -40,7 +41,7 @@ import org.apache.spark.sql.DataFrame
 @Experimental
 final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
-  with DecisionTreeParams with TreeRegressorParams {
+  with DecisionTreeRegressorParams {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("dtr"))
@@ -73,6 +74,9 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
 
   override def setSeed(value: Long): this.type = super.setSeed(value)
 
+  /** @group setParam */
+  def setVarianceCol(value: String): this.type = set(varianceCol, value)
+
   override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -113,7 +117,10 @@ final class DecisionTreeRegressionModel private[ml] (
     override val rootNode: Node,
     override val numFeatures: Int)
   extends PredictionModel[Vector, DecisionTreeRegressionModel]
-  with DecisionTreeModel with Serializable {
+  with DecisionTreeModel with DecisionTreeRegressorParams with Serializable {
+
+  /** @group setParam */
+  def setVarianceCol(value: String): this.type = set(varianceCol, value)
 
   require(rootNode != null,
     "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
@@ -129,6 +136,29 @@ final class DecisionTreeRegressionModel private[ml] (
     rootNode.predictImpl(features).prediction
   }
 
+  /** We need to update this function if we ever add other impurity measures. */
+  protected def predictVariance(features: Vector): Double = {
+    rootNode.predictImpl(features).impurityStats.calculate()
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+    transformImpl(dataset)
+  }
+
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val predictUDF = udf { (features: Vector) => predict(features) }
+    val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
+    var output = dataset
+    if ($(predictionCol).nonEmpty) {
+      output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    }
+    if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+      output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol))))
+    }
+    output
+  }
+
   @Since("1.4.0")
   override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
     copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent)

http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 1da97db..7443097 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -20,9 +20,11 @@ package org.apache.spark.ml.tree
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
 import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
 
 /**
  * Parameters for Decision Tree-based algorithms.
@@ -256,6 +258,22 @@ private[ml] object TreeRegressorParams {
   final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
 }
 
+private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
+  with TreeRegressorParams with HasVarianceCol {
+
+  override protected def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
+    val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+    if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+      SchemaUtils.appendColumn(newSchema, $(varianceCol), DoubleType)
+    } else {
+      newSchema
+    }
+  }
+}
+
 /**
  * Parameters for Decision Tree-based ensemble algorithms.
  *

http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 6999a91..0b39af5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.regression
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
   DecisionTreeSuite => OldDecisionTreeSuite}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Row, DataFrame}
 
 
 class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -73,6 +74,29 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
     MLTestingUtils.checkCopy(model)
   }
 
+  test("predictVariance") {
+    val dt = new DecisionTreeRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(2)
+      .setMaxBins(100)
+      .setPredictionCol("")
+      .setVarianceCol("variance")
+    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+
+    val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+    val model = dt.fit(df)
+
+    val predictions = model.transform(df)
+      .select(model.getFeaturesCol, model.getVarianceCol)
+      .collect()
+
+    predictions.foreach { case Row(features: Vector, variance: Double) =>
+      val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
+      assert(variance === expectedVariance,
+        s"Expected variance $expectedVariance but got $variance.")
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org