Posted to reviews@spark.apache.org by smurching <gi...@git.apache.org> on 2017/12/01 07:57:09 UTC

[GitHub] spark pull request #19758: [SPARK-3162][MLlib] Local Tree Training Pt 1: Ref...

Github user smurching commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19758#discussion_r154284858
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala ---
    @@ -0,0 +1,280 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split}
    +import org.apache.spark.ml.util.DefaultReadWriteTest
    +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity}
    +import org.apache.spark.mllib.tree.model.ImpurityStats
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +
    +/** Suite exercising helper methods for making split decisions during decision tree training. */
    +class TreeSplitUtilsSuite
    +  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
    +
    +  /**
    +   * Get a DTStatsAggregator for sufficient stat collection/impurity calculation, populated
    +   * with data from the specified training points. Assumes a feature index of 0 and that
    +   * all training points have the same weight (1.0).
    +   */
    +  private def getAggregator(
    +      metadata: DecisionTreeMetadata,
    +      values: Array[Int],
    +      labels: Array[Double],
    +      featureSplits: Array[Split]): DTStatsAggregator = {
    +    // Create stats aggregator
    +    val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None)
    +    // Update parent impurity stats
    +    val featureIndex = 0
    +    val instanceWeights = Array.fill[Double](values.length)(1.0)
    +    AggUpdateUtils.updateParentImpurity(statsAggregator, indices = values.indices.toArray,
    +      from = 0, to = values.length, instanceWeights, labels)
    +    // Update current aggregator's impurity stats
    +    values.zip(labels).foreach { case (value: Int, label: Double) =>
    +      if (metadata.isUnordered(featureIndex)) {
    +        AggUpdateUtils.updateUnorderedFeature(statsAggregator, value, label,
    +          featureIndex = featureIndex, featureIndexIdx = 0, featureSplits, instanceWeight = 1.0)
    +      } else {
    +        AggUpdateUtils.updateOrderedFeature(statsAggregator, value, label, featureIndexIdx = 0,
    +          instanceWeight = 1.0)
    +      }
    +    }
    +    statsAggregator
    +  }
    +
    +  /**
    +   * Check that left/right impurities match what we'd expect for a split.
    +   * @param impurity Impurity measure used to compute the expected total impurity
    +   * @param labels Labels whose impurity information should be reflected in stats
    +   * @param stats ImpurityStats object containing impurity info for the left/right sides of a split
    +   * @param expectedLeftStats Expected sufficient stats (class counts) for the left side of the split
    +   * @param expectedRightStats Expected sufficient stats (class counts) for the right side of the split
    +   */
    +  private def validateImpurityStats(
    +      impurity: Impurity,
    +      labels: Array[Double],
    +      stats: ImpurityStats,
    +      expectedLeftStats: Array[Double],
    +      expectedRightStats: Array[Double]): Unit = {
    +    // Compute impurity for our data points manually
    +    val numClasses = (labels.max + 1).toInt
    +    val fullImpurityStatsArray =
    +      Array.tabulate[Double](numClasses)((label: Int) => labels.count(_ == label).toDouble)
    +    val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
    +    // Verify that impurity stats were computed correctly for split
    +    assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
    +    assert(stats.impurity === fullImpurity)
    +    assert(stats.leftImpurityCalculator.stats === expectedLeftStats)
    +    assert(stats.rightImpurityCalculator.stats === expectedRightStats)
    +    assert(stats.valid)
    +  }
    +
    +  /* * * * * * * * * * * Choosing Splits  * * * * * * * * * * */
    +
    +  test("chooseSplit: choose correct type of split (continuous split)") {
    +    // Construct (binned) continuous data
    +    val labels = Array(0.0, 0.0, 1.0)
    +    val values = Array(1, 2, 3)
    +    val featureIndex = 0
    +    // Get an array of continuous splits corresponding to values in our binned data
    +    val splits = TreeTests.getContinuousSplits(thresholds = values.distinct.sorted,
    +      featureIndex = 0)
    +    // Construct DTStatsAggregator, compute sufficient stats
    +    val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
    +      numClasses = 2, Map.empty)
    +    val statsAggregator = getAggregator(metadata, values, labels, splits)
    +    // Choose split, check that it's a valid ContinuousSplit
    +    val (split, stats) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex,
    +      splits)
    +    assert(stats.valid && split.isInstanceOf[ContinuousSplit])
    +  }
    +
    +  test("chooseSplit: choose correct type of split (categorical split)") {
    +    // Construct categorical data
    +    val labels = Array(0.0, 0.0, 1.0, 1.0, 1.0)
    +    val featureArity = 3
    +    val values = Array(0, 0, 1, 2, 2)
    +    val featureIndex = 0
    +    // Construct DTStatsAggregator, compute sufficient stats
    +    val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
    +      numClasses = 2, Map(featureIndex -> featureArity))
    +    val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
    +    val statsAggregator = getAggregator(metadata, values, labels, splits)
    +    // Choose split, check that it's a valid categorical split
    +    val (split, stats) = SplitUtils.chooseSplit(statsAggregator = statsAggregator,
    +      featureIndex = featureIndex, featureIndexIdx = featureIndex, featureSplits = splits)
    +    assert(stats.valid && split.isInstanceOf[CategoricalSplit])
    +  }
    +
    +  test("chooseOrderedCategoricalSplit: basic case") {
    +    // Helper method for testing ordered categorical split
    +    def testHelper(
    +        values: Array[Int],
    +        labels: Array[Double],
    +        expectedLeftCategories: Array[Double],
    +        expectedLeftStats: Array[Double],
    +        expectedRightStats: Array[Double]): Unit = {
    +      // Set up metadata for ordered categorical feature
    +      val featureIndex = 0
    +      val featureArity = values.max + 1
    +      val arityMap = Map[Int, Int](featureIndex -> featureArity)
    +      val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
    +        numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty))
    +      // Construct DTStatsAggregator, compute sufficient stats
    +      val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array.empty)
    +      // Choose split
    +      val (split, stats) =
    +        SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex)
    +      // Verify that split has the expected left-side/right-side categories
    +      val expectedRightCategories = Range(0, featureArity)
    +        .filter(c => !expectedLeftCategories.contains(c.toDouble)).map(_.toDouble).toArray
    +      split match {
    +        case s: CategoricalSplit =>
    +          assert(s.featureIndex === featureIndex)
    +          assert(s.leftCategories === expectedLeftCategories)
    +          assert(s.rightCategories === expectedRightCategories)
    +        case _ =>
    +          throw new AssertionError(
    +            s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}")
    +      }
    +      validateImpurityStats(Entropy, labels, stats, expectedLeftStats, expectedRightStats)
    +    }
    +
    +    // Test a single split: the left side of the split should contain the two points with
    +    // label 0; the right side should contain the five points with label 1.
    +    val values = Array(0, 0, 1, 2, 2, 2, 2)
    +    val labels1 = Array(0, 0, 1, 1, 1, 1, 1).map(_.toDouble)
    +    testHelper(values, labels1, expectedLeftCategories = Array(0.0),
    +      expectedLeftStats = Array(2.0, 0.0), expectedRightStats = Array(0.0, 5.0))
    +
    +    // Test a single split: the left side of the split should contain the three points with
    +    // label 0; the right side should contain the four points with label 1.
    +    val labels2 = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble)
    +    testHelper(values, labels2, expectedLeftCategories = Array(0.0, 1.0),
    +      expectedLeftStats = Array(3.0, 0.0), expectedRightStats = Array(0.0, 4.0))
    +  }
    +
    +  test("chooseOrderedCategoricalSplit: return bad stats if we should not split") {
    +    // Construct categorical data
    +    val featureIndex = 0
    +    val values = Array(0, 0, 1, 2, 2, 2, 2)
    +    val featureArity = values.max + 1
    +    val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
    +    // Construct DTStatsAggregator, compute sufficient stats
    +    val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
    +      numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty))
    +    val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array.empty)
    +    // Choose split, verify that it's invalid
    +    val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex,
    +        featureIndex)
    +    assert(!stats.valid)
    +  }
    +
    +  test("chooseUnorderedCategoricalSplit: basic case") {
    +    val featureIndex = 0
    +    // Construct data for unordered categorical feature
    +    // label: 0 --> values: 1
    +    // label: 1 --> values: 0, 2
    +    // label: 2 --> values: 2
    +    // Expected split: feature value 1 on the left, values (0, 2) on the right
    +    val values = Array(1, 1, 0, 2, 2)
    +    val featureArity = values.max + 1
    --- End diff --
    
    @WeichenXu123 thanks for the feedback! Definitely agree that the test is a little weak right now.
    
    IMO it's mainly weak due to the low feature arity (there are only three possible splits, so the right one could be picked by chance). I think increasing the number of classes/examples substantially might make the test harder to reason about, but I'm not opposed to that either - let me know what you think.
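    
    For context on the arity math (numUnorderedSplits below is just an illustrative helper, not something in this PR):
    
    ```
        // An unordered categorical feature considers every bipartition of the category set,
        // excluding the empty/full sets and mirrored duplicates: 2^(arity - 1) - 1 candidates.
        def numUnorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1
        numUnorderedSplits(3)  // 3 candidates with the current data
        numUnorderedSplits(5)  // 15 candidates with the data proposed below
    ```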
    
    What about something like:
    
    ```
        val values = Array(0, 1, 2, 3, 2, 2, 4)
        val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0)
        // label: 0 --> values: 0, 1
        // label: 1 --> values: 2, 3
        // label: 2 --> values: 2, 2, 4
        // Expected split: feature values (0, 1) on the left, values (2, 3, 4) on the right
    ```
    
    This way we still test multiclass classification & test the split-selection logic more rigorously.
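    
    To be concrete, the updated test body could reuse the helpers already in this diff - just a sketch, and I'm assuming chooseUnorderedCategoricalSplit takes featureSplits the same way chooseSplit does:
    
    ```
        val featureIndex = 0
        val values = Array(0, 1, 2, 3, 2, 2, 4)
        val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0)
        // Arity 5 => 2^(5 - 1) - 1 = 15 candidate bipartitions to choose among
        val featureArity = values.max + 1
        val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
          numClasses = 3, Map(featureIndex -> featureArity))
        val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
        val statsAggregator = getAggregator(metadata, values, labels, splits)
        val (split, stats) = SplitUtils.chooseUnorderedCategoricalSplit(
          statsAggregator, featureIndex, featureIndex, splits)
        // Expected: feature values (0, 1) on the left, (2, 3, 4) on the right
    ```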
    


