Posted to commits@spark.apache.org by jk...@apache.org on 2017/01/18 23:33:47 UTC

spark git commit: [SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces

Repository: spark
Updated Branches:
  refs/heads/master a81e336f1 -> fe409f31d


[SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces

## What changes were proposed in this pull request?

All of the classifiers in MLlib can predict probabilities except for GBTClassifier.
Every other classifier inherits from ProbabilisticClassifier, but GBTClassifier inherits from Predictor, which is a bug.
This change corrects the interface and adds the ability for the classifier to produce a probability vector.
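
For context, a minimal usage sketch of the corrected API (illustrative only, not part of the patch; it assumes a training DataFrame `train` with the default `label` and `features` columns):

```scala
import org.apache.spark.ml.classification.GBTClassifier

// After this change the fitted model is a ProbabilisticClassificationModel,
// so transform() populates rawPrediction and probability like the other
// classifiers.
val gbt = new GBTClassifier().setMaxIter(10).setSeed(42L)
val model = gbt.fit(train)
model.transform(train)
  .select("prediction", "rawPrediction", "probability")
  .show(5)
```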

## How was this patch tested?

The basic ML tests were run after making the changes. I've marked this as WIP since I need to add more tests.

Author: Ilya Matiach <il...@microsoft.com>

Closes #16441 from imatiach-msft/ilmat/fix-GBT.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fe409f31
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fe409f31
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fe409f31

Branch: refs/heads/master
Commit: fe409f31d966d99fcf57137581d1fb682c1c072a
Parents: a81e336
Author: Ilya Matiach <il...@microsoft.com>
Authored: Wed Jan 18 15:33:41 2017 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Jan 18 15:33:41 2017 -0800

----------------------------------------------------------------------
 .../spark/ml/classification/GBTClassifier.scala |  94 ++++++++---
 .../org/apache/spark/ml/tree/treeParams.scala   |   4 +-
 .../apache/spark/mllib/tree/loss/LogLoss.scala  |  10 +-
 .../org/apache/spark/mllib/tree/loss/Loss.scala |   8 +-
 .../ml/classification/GBTClassifierSuite.scala  | 161 ++++++++++++++++++-
 5 files changed, 248 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fe409f31/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c9bbd37..ade0960 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -23,9 +23,8 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
@@ -33,6 +32,7 @@ import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.LogLoss
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -58,7 +58,7 @@ import org.apache.spark.sql.functions._
 @Since("1.4.0")
 class GBTClassifier @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
-  extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
+  extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel]
   with GBTClassifierParams with DefaultParamsWritable with Logging {
 
   @Since("1.4.0")
@@ -158,12 +158,19 @@ class GBTClassifier @Since("1.4.0") (
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
 
+    val numClasses = 2
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
+
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
       seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
     instr.logNumFeatures(numFeatures)
-    instr.logNumClasses(2)
+    instr.logNumClasses(numClasses)
 
     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
       $(seed))
@@ -202,8 +209,9 @@ class GBTClassificationModel private[ml](
     @Since("1.6.0") override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
     private val _treeWeights: Array[Double],
-    @Since("1.6.0") override val numFeatures: Int)
-  extends PredictionModel[Vector, GBTClassificationModel]
+    @Since("1.6.0") override val numFeatures: Int,
+    @Since("2.2.0") override val numClasses: Int)
+  extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
   with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
   with MLWritable with Serializable {
 
@@ -216,10 +224,24 @@ class GBTClassificationModel private[ml](
    *
    * @param _trees  Decision trees in the ensemble.
    * @param _treeWeights  Weights for the decision trees in the ensemble.
+   * @param numFeatures  The number of features.
+   */
+  private[ml] def this(
+      uid: String,
+      _trees: Array[DecisionTreeRegressionModel],
+      _treeWeights: Array[Double],
+      numFeatures: Int) =
+  this(uid, _trees, _treeWeights, numFeatures, 2)
+
+  /**
+   * Construct a GBTClassificationModel
+   *
+   * @param _trees  Decision trees in the ensemble.
+   * @param _treeWeights  Weights for the decision trees in the ensemble.
    */
   @Since("1.6.0")
   def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
-    this(uid, _trees, _treeWeights, -1)
+    this(uid, _trees, _treeWeights, -1, 2)
 
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
@@ -242,11 +264,29 @@ class GBTClassificationModel private[ml](
   }
 
   override protected def predict(features: Vector): Double = {
-    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
-    // Classifies by thresholding sum of weighted tree predictions
-    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
-    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
-    if (prediction > 0.0) 1.0 else 0.0
+    // If thresholds are defined, use predictRaw to get probabilities; otherwise threshold the margin directly
+    if (isDefined(thresholds)) {
+      super.predict(features)
+    } else {
+      if (margin(features) > 0.0) 1.0 else 0.0
+    }
+  }
+
+  override protected def predictRaw(features: Vector): Vector = {
+    val prediction: Double = margin(features)
+    Vectors.dense(Array(-prediction, prediction))
+  }
+
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        dv.values(0) = loss.computeProbability(dv.values(0))
+        dv.values(1) = 1.0 - dv.values(0)
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in GBTClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
   }
 
   /** Number of trees in ensemble */
@@ -254,7 +294,7 @@ class GBTClassificationModel private[ml](
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): GBTClassificationModel = {
-    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
+    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
       extra).setParent(parent)
   }
 
@@ -276,11 +316,20 @@ class GBTClassificationModel private[ml](
   @Since("2.0.0")
   lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
 
+  /** Raw prediction for the positive class. */
+  private def margin(features: Vector): Double = {
+    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+  }
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
   }
 
+  // Hard-coded loss, which is not meant to be changed in the model
+  private val loss = getOldLossType
+
   @Since("2.0.0")
   override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
 }
@@ -288,6 +337,9 @@ class GBTClassificationModel private[ml](
 @Since("2.0.0")
 object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
 
+  private val numFeaturesKey: String = "numFeatures"
+  private val numTreesKey: String = "numTrees"
+
   @Since("2.0.0")
   override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
 
@@ -300,8 +352,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
     override protected def saveImpl(path: String): Unit = {
 
       val extraMetadata: JObject = Map(
-        "numFeatures" -> instance.numFeatures,
-        "numTrees" -> instance.getNumTrees)
+        numFeaturesKey -> instance.numFeatures,
+        numTreesKey -> instance.getNumTrees)
       EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
     }
   }
@@ -316,8 +368,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
         EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
-      val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
-      val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+      val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
+      val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
@@ -328,7 +380,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       }
       require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
         s" trees based on metadata but found ${trees.length} trees.")
-      val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+      val model = new GBTClassificationModel(metadata.uid,
+        trees, treeWeights, numFeatures)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -339,7 +392,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       oldModel: OldGBTModel,
       parent: GBTClassifier,
       categoricalFeatures: Map[Int, Int],
-      numFeatures: Int = -1): GBTClassificationModel = {
+      numFeatures: Int = -1,
+      numClasses: Int = 2): GBTClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
       s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -347,6 +401,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
-    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
+    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
   }
 }
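
To see how the pieces above compose, here is a standalone sketch of the new
prediction path (hypothetical tree outputs and weights, restating the logic
rather than calling into the model):

// margin(features): weighted sum of the regression trees' predictions
val treePredictions = Array(0.8, -0.2, 0.5)  // hypothetical per-tree outputs
val treeWeights = Array(1.0, 0.5, 0.25)      // hypothetical tree weights
val margin = treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum

// predictRaw(features): raw scores for the two classes
val raw = Array(-margin, margin)

// raw2probabilityInPlace: LogLoss.computeProbability applied to raw(0)
// yields P(label = 0); its complement is P(label = 1)
val p0 = 1.0 / (1.0 + math.exp(-2.0 * raw(0)))
val p1 = 1.0 - p0

// predict: with no thresholds set, threshold the margin directly
val prediction = if (margin > 0.0) 1.0 else 0.0
println(s"margin=$margin raw=[${raw(0)}, ${raw(1)}] prob=[$p0, $p1] prediction=$prediction")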

http://git-wip-us.apache.org/repos/asf/spark/blob/fe409f31/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index c7a8f76..5eb707d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
@@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam
   def getLossType: String = $(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss. */
-  override private[ml] def getOldLossType: OldLoss = {
+  override private[ml] def getOldLossType: OldClassificationLoss = {
     getLossType match {
       case "logistic" => OldLogLoss
       case _ =>

http://git-wip-us.apache.org/repos/asf/spark/blob/fe409f31/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 5d92ce4..9339f0a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.mllib.util.MLUtils
 
-
 /**
  * :: DeveloperApi ::
  * Class for log loss calculation (for classification).
@@ -32,7 +31,7 @@ import org.apache.spark.mllib.util.MLUtils
  */
 @Since("1.2.0")
 @DeveloperApi
-object LogLoss extends Loss {
+object LogLoss extends ClassificationLoss {
 
   /**
    * Method to calculate the loss gradients for the gradient boosting calculation for binary
@@ -52,4 +51,11 @@ object LogLoss extends Loss {
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
   }
+
+  /**
+   * Returns the estimated probability of a label of 1.0.
+   */
+  override private[spark] def computeProbability(margin: Double): Double = {
+    1.0 / (1.0 + math.exp(-2.0 * margin))
+  }
 }
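
The factor of 2 in the exponent mirrors computeError above, where margin is
already 2 * label * prediction for labels in {-1, +1}; the implied probability
of the positive class is the logistic function of twice the raw prediction.
A standalone sanity check (the formula restated, not imported from MLlib):

def computeProbability(margin: Double): Double =
  1.0 / (1.0 + math.exp(-2.0 * margin))

// Complementary margins give complementary probabilities,
// since 1 / (1 + e^(-2m)) + 1 / (1 + e^(2m)) = 1.
val m = 0.7
assert(math.abs(computeProbability(m) + computeProbability(-m) - 1.0) < 1e-12)

// A zero margin is maximally uncertain.
assert(computeProbability(0.0) == 0.5)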

http://git-wip-us.apache.org/repos/asf/spark/blob/fe409f31/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 09274a2..e7ffb3f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,7 +22,6 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
-
 /**
  * :: DeveloperApi ::
  * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -67,3 +66,10 @@ trait Loss extends Serializable {
    */
   private[spark] def computeError(prediction: Double, label: Double): Double
 }
+
+private[spark] trait ClassificationLoss extends Loss {
+  /**
+   * Computes the class probability given the margin.
+   */
+  private[spark] def computeProbability(margin: Double): Double
+}
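
The new trait is the seam that lets GBTClassificationModel map margins to
probabilities without hard-coding LogLoss. A self-contained sketch of the
pattern (the trait is re-declared here for illustration rather than imported):

trait ClassificationLossSketch {
  def computeProbability(margin: Double): Double
}

object LogLossSketch extends ClassificationLossSketch {
  // Same mapping as LogLoss above: the logistic function of twice the margin.
  override def computeProbability(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))
}

// Code written against the trait works for any future classification loss
// that supplies its own margin-to-probability mapping.
def rawToProbabilities(margin: Double, loss: ClassificationLossSketch): Array[Double] = {
  val p0 = loss.computeProbability(-margin)
  Array(p0, 1.0 - p0)
}

println(rawToProbabilities(1.2, LogLossSketch).mkString("[", ", ", "]"))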

http://git-wip-us.apache.org/repos/asf/spark/blob/fe409f31/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 7c36745..0598943 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,20 +17,24 @@
 
 package org.apache.spark.ml.classification
 
+import com.github.fommil.netlib.BLAS
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.LogLoss
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils
 
 /**
@@ -49,6 +53,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
   private var data: RDD[LabeledPoint] = _
   private var trainData: RDD[LabeledPoint] = _
   private var validationData: RDD[LabeledPoint] = _
+  private val eps: Double = 1e-5
+  private val absEps: Double = 1e-8
 
   override def beforeAll() {
     super.beforeAll()
@@ -66,10 +72,156 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     ParamsSuite.checkParams(new GBTClassifier)
     val model = new GBTClassificationModel("gbtc",
       Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
-      Array(1.0), 1)
+      Array(1.0), 1, 2)
     ParamsSuite.checkParams(model)
   }
 
+  test("GBTClassifier: default params") {
+    val gbt = new GBTClassifier
+    assert(gbt.getLabelCol === "label")
+    assert(gbt.getFeaturesCol === "features")
+    assert(gbt.getPredictionCol === "prediction")
+    assert(gbt.getRawPredictionCol === "rawPrediction")
+    assert(gbt.getProbabilityCol === "probability")
+    val df = trainData.toDF()
+    val model = gbt.fit(df)
+    model.transform(df)
+      .select("label", "probability", "prediction", "rawPrediction")
+      .collect()
+    intercept[NoSuchElementException] {
+      model.getThresholds
+    }
+    assert(model.getFeaturesCol === "features")
+    assert(model.getPredictionCol === "prediction")
+    assert(model.getRawPredictionCol === "rawPrediction")
+    assert(model.getProbabilityCol === "probability")
+    assert(model.hasParent)
+
+    // copied model must have the same parent.
+    MLTestingUtils.checkCopy(model)
+  }
+
+  test("setThreshold, getThreshold") {
+    val gbt = new GBTClassifier
+
+    // default
+    withClue("GBTClassifier should not have thresholds set by default.") {
+      intercept[NoSuchElementException] {
+        gbt.getThresholds
+      }
+    }
+
+    // Set via thresholds
+    val gbt2 = new GBTClassifier
+    val threshold = Array(0.3, 0.7)
+    gbt2.setThresholds(threshold)
+    assert(gbt2.getThresholds === threshold)
+  }
+
+  test("thresholds prediction") {
+    val gbt = new GBTClassifier
+    val df = trainData.toDF()
+    val binaryModel = gbt.fit(df)
+
+    // should predict all zeros
+    binaryModel.setThresholds(Array(0.0, 1.0))
+    val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect()
+    assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+
+    // should predict all ones
+    binaryModel.setThresholds(Array(1.0, 0.0))
+    val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect()
+    assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
+
+
+    val gbtBase = new GBTClassifier
+    val model = gbtBase.fit(df)
+    val basePredictions = model.transform(df).select("prediction").collect()
+
+    // constant threshold scaling is the same as no thresholds
+    binaryModel.setThresholds(Array(1.0, 1.0))
+    val scaledPredictions = binaryModel.transform(df).select("prediction").collect()
+    assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+      scaled.getDouble(0) === base.getDouble(0)
+    })
+
+    // force it to use the predict method
+    model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1))
+    val predictionsWithPredict = model.transform(df).select("prediction").collect()
+    assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+  }
+
+  test("GBTClassifier: Predictor, Classifier methods") {
+    val rawPredictionCol = "rawPrediction"
+    val predictionCol = "prediction"
+    val labelCol = "label"
+    val featuresCol = "features"
+    val probabilityCol = "probability"
+
+    val gbt = new GBTClassifier().setSeed(123)
+    val trainingDataset = trainData.toDF(labelCol, featuresCol)
+    val gbtModel = gbt.fit(trainingDataset)
+    assert(gbtModel.numClasses === 2)
+    val numFeatures = trainingDataset.select(featuresCol).first().getAs[Vector](0).size
+    assert(gbtModel.numFeatures === numFeatures)
+
+    val blas = BLAS.getInstance()
+
+    val validationDataset = validationData.toDF(labelCol, featuresCol)
+    val results = gbtModel.transform(validationDataset)
+    // check that raw prediction is tree predictions dot tree weights
+    results.select(rawPredictionCol, featuresCol).collect().foreach {
+      case Row(raw: Vector, features: Vector) =>
+        assert(raw.size === 2)
+        val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
+        val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
+        assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
+    }
+
+    // Compare rawPrediction with probability
+    results.select(rawPredictionCol, probabilityCol).collect().foreach {
+      case Row(raw: Vector, prob: Vector) =>
+        assert(raw.size === 2)
+        assert(prob.size === 2)
+        // Note: we should check other loss types for classification if they are added
+        val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
+        assert(prob(0) ~== predFromRaw(0) relTol eps)
+        assert(prob(1) ~== predFromRaw(1) relTol eps)
+        assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
+    }
+
+    // Compare prediction with probability
+    results.select(predictionCol, probabilityCol).collect().foreach {
+      case Row(pred: Double, prob: Vector) =>
+        val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+        assert(pred == predFromProb)
+    }
+
+    // force it to use raw2prediction
+    gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
+    val resultsUsingRaw2Predict =
+      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+    resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use probability2prediction
+    gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
+    val resultsUsingProb2Predict =
+      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+    resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use predict
+    gbtModel.setRawPredictionCol("").setProbabilityCol("")
+    val resultsUsingPredict =
+      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+    resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+  }
+
   test("GBT parameter stepSize should be in interval (0, 1]") {
     withClue("GBT parameter stepSize should be in interval (0, 1]") {
       intercept[IllegalArgumentException] {
@@ -246,7 +398,8 @@ private object GBTClassifierSuite extends SparkFunSuite {
     val newModel = gbt.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = GBTClassificationModel.fromOld(
-      oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures)
+      oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures,
+      numFeatures, numClasses = 2)
     TreeTests.checkEqual(oldModelAsNew, newModel)
     assert(newModel.numFeatures === numFeatures)
     assert(oldModelAsNew.numFeatures === numFeatures)
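
As the tests above exercise, thresholds now interact with GBT predictions in
the standard ProbabilisticClassifier way. A brief usage sketch (assuming a
fitted model: GBTClassificationModel and a DataFrame df with a features
column; illustrative only):

// A zero threshold for class 0 makes class 0 win every comparison,
// so every row is predicted 0.0 (mirrors the "thresholds prediction" test).
model.setThresholds(Array(0.0, 1.0))
val allZeros = model.transform(df).select("prediction").collect()
  .forall(_.getDouble(0) == 0.0)

// Reversing the thresholds flips the effect: every row is predicted 1.0.
model.setThresholds(Array(1.0, 0.0))
val allOnes = model.transform(df).select("prediction").collect()
  .forall(_.getDouble(0) == 1.0)

println(s"all zeros: $allZeros, all ones: $allOnes")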

