You are viewing a plain-text version of this content. The link to the canonical version (originally a hyperlink labeled "here") is not preserved in this plain-text format.
Posted to reviews@spark.apache.org by smurakozi <gi...@git.apache.org> on 2018/01/17 14:31:33 UTC

[GitHub] spark pull request #20121: [SPARK-22882][ML][TESTS] ML test for structured s...

Github user smurakozi commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20121#discussion_r162037613
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala ---
    @@ -169,59 +171,28 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
         val blas = BLAS.getInstance()
     
         val validationDataset = validationData.toDF(labelCol, featuresCol)
    -    val results = gbtModel.transform(validationDataset)
    -    // check that raw prediction is tree predictions dot tree weights
    -    results.select(rawPredictionCol, featuresCol).collect().foreach {
    -      case Row(raw: Vector, features: Vector) =>
    +    testTransformer[(Double, Vector)](validationDataset, gbtModel,
    +      "rawPrediction", "features", "probability", "prediction") {
    +      case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) =>
             assert(raw.size === 2)
    +        // check that raw prediction is tree predictions dot tree weights
             val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
             val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
             assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
    -    }
     
    -    // Compare rawPrediction with probability
    -    results.select(rawPredictionCol, probabilityCol).collect().foreach {
    -      case Row(raw: Vector, prob: Vector) =>
    -        assert(raw.size === 2)
    +        // Compare rawPrediction with probability
             assert(prob.size === 2)
             // Note: we should check other loss types for classification if they are added
             val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
             assert(prob(0) ~== predFromRaw(0) relTol eps)
             assert(prob(1) ~== predFromRaw(1) relTol eps)
             assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
    -    }
     
    -    // Compare prediction with probability
    -    results.select(predictionCol, probabilityCol).collect().foreach {
    -      case Row(pred: Double, prob: Vector) =>
    +        // Compare prediction with probability
             val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
             assert(pred == predFromProb)
         }
     
    -    // force it to use raw2prediction
    -    gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
    -    val resultsUsingRaw2Predict =
    -      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
    -    resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
    -      case (pred1, pred2) => assert(pred1 === pred2)
    -    }
    -
    -    // force it to use probability2prediction
    -    gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
    -    val resultsUsingProb2Predict =
    -      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
    -    resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
    -      case (pred1, pred2) => assert(pred1 === pred2)
    -    }
    -
    -    // force it to use predict
    --- End diff --
    
    Why were these transformations and checks removed?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org