Posted to reviews@spark.apache.org by jkbradley <gi...@git.apache.org> on 2015/04/15 18:31:29 UTC

[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

GitHub user jkbradley opened a pull request:

    https://github.com/apache/spark/pull/5530

    [SPARK-6113] [ml] Stabilize DecisionTree API

    This is a PR for cleaning up and finalizing the DecisionTree API.  PRs for ensembles will follow once this is merged.
    
    ### Goal
    
    Here is the description copied from the JIRA (for both trees and ensembles):
    
    > **Issue**: The APIs for DecisionTree and ensembles (RandomForests and GradientBoostedTrees) have been experimental for a long time. The API has become very convoluted because trees and ensembles have many, many variants, some of which we have added incrementally without a long-term design.
    > **Proposal**: This JIRA is for discussing changes required to finalize the APIs. After we discuss, I will make a PR to update the APIs and make them non-Experimental. This will require making many breaking changes; see the design doc for details.
    > **[Design doc](https://docs.google.com/document/d/1rJ_DZinyDG3PkYkAKSsQlY0QgCeefn4hUv7GsPkzBP4)** : This outlines current issues and the proposed API.
    
    Overall code layout:
    * The old API in mllib.tree.* will remain the same.
    * The new API will reside in ml.classification.* and ml.regression.*
    
    ### Summary of changes
    
    Old API
    * Exactly the same, except that I made one method in Loss private (this is not a breaking change because that method was introduced after the Spark 1.3 release).
    
    New APIs
    * Under Pipeline API
    * The new API preserves functionality, except:
      * The new API does NOT store prob (the probability of the predicted label in classification). I would like it to store the full vector of probabilities, but I think that belongs in a later PR.
    * Use abstractions for parameters, estimators, and models to avoid code duplication
    * Limit parameters to relevant algorithms
    * For enum-like types, only expose Strings
      * We can make these pluggable later on by adding new parameters.  That is a far-future item.
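A minimal sketch of what "only expose Strings" for an enum-like option can look like. It is in Java because the new API also targets Java callers. The class name `ImpurityOption`, the "gini" default, the lowercase normalization, and the error message are illustrative assumptions, not the PR's code. The accepted values ("entropy", "gini") are the classification impurities of the existing mllib.tree API.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

// Hypothetical sketch: an enum-like option exposed to users only as a String.
class ImpurityOption {
  // Assumed set of accepted values (classification impurities).
  static final List<String> SUPPORTED =
      Collections.unmodifiableList(Arrays.asList("entropy", "gini"));

  // Assumed default value.
  private String impurity = "gini";

  // Validate and store the value; returns this for builder-style chaining.
  ImpurityOption setImpurity(String value) {
    String normalized = value.toLowerCase(Locale.ROOT);
    if (!SUPPORTED.contains(normalized)) {
      throw new IllegalArgumentException(
          "Unsupported impurity: " + value + "; expected one of " + SUPPORTED);
    }
    this.impurity = normalized;
    return this;
  }

  String getImpurity() {
    return impurity;
  }
}
```

Keeping the public surface as a String means new values can later be accepted without changing any method signature, which fits the "pluggable later" note.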
    
    Test suites
    * I reorganized DecisionTreeSuite, but I made no changes to the tests themselves.
    * The test suites for the new API only test (a) similarity with the results of the old API and (b) elements of the new API.
      * After code is moved to this new API, we should move the tests from the old suites which test the internals.
    
    ### Details
    
    #### Changed names
    
    Parameters
    * useNodeIdCache -> cacheNodeIds
    
    #### Other changes
    
    * Split: Changed categories to set instead of list
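Storing categories as a set rather than a list gives an expected constant-time membership test when routing an instance and rules out duplicate categories. A minimal sketch of that idea. The class `CategoricalSplitSketch`, its fields, and `goesLeft` are hypothetical illustrations, not the PR's `Split` class.

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

// Hypothetical categorical split: feature values in leftCategories go to the left child.
final class CategoricalSplitSketch {
  private final int featureIndex;
  private final Set<Double> leftCategories;

  CategoricalSplitSketch(int featureIndex, Iterable<Double> categories) {
    this.featureIndex = featureIndex;
    Set<Double> copy = new HashSet<>();
    for (Double c : categories) {
      copy.add(c); // duplicates collapse automatically
    }
    this.leftCategories = Collections.unmodifiableSet(copy);
  }

  // Expected O(1) membership test, versus a linear scan over a list.
  boolean goesLeft(double[] features) {
    return leftCategories.contains(features[featureIndex]);
  }

  int numLeftCategories() {
    return leftCategories.size();
  }
}
```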
    
    #### Non-decision tree changes
    * AttributeGroup
      * Added parentheses back to the toMetadata and toStructField methods. (These were removed in a previous PR, but I ran into an issue where the Scala compiler could not disambiguate between a toMetadata method with no parentheses and a toMetadata method that takes one argument.)
    * Attributes
      * Renamed: toMetadata -> toMetadataImpl
      * Added toMetadata methods which return ML metadata (keyed with "ML_ATTR")
      * NominalAttribute: Added getNumValues method which examines both numValues and values.
    * Params.inheritValues: Checks whether the parent param really belongs to the child (to allow Estimator-Model pairs with different sets of parameters)
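The last two items can be illustrated with simplified stand-ins. When a Model declares fewer params than its Estimator, inheriting values should skip the ones the Model lacks. A nominal attribute's value count may come either from an explicit count or from its list of values. `ParamHolder`, `NominalAttributeSketch`, and the param names used in the test are hypothetical, not Spark's `ml.param` or `ml.attribute` classes. Letting an explicit count win when both are present is an assumption.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Set;

// Hypothetical, simplified stand-in for a param holder (not Spark's ml.param API).
class ParamHolder {
  private final Set<String> declared;
  final Map<String, Object> values = new HashMap<>();

  ParamHolder(Set<String> declared) {
    this.declared = declared;
  }

  boolean hasParam(String name) {
    return declared.contains(name);
  }

  // Copy the parent's set values, skipping params this holder does not declare.
  void inheritValuesFrom(ParamHolder parent) {
    for (Map.Entry<String, Object> e : parent.values.entrySet()) {
      if (hasParam(e.getKey())) {
        values.put(e.getKey(), e.getValue());
      }
    }
  }
}

// Hypothetical nominal attribute: the count may come from numValues or from values.
final class NominalAttributeSketch {
  private final Integer numValues; // may be null
  private final String[] values;   // may be null

  NominalAttributeSketch(Integer numValues, String[] values) {
    this.numValues = numValues;
    this.values = values;
  }

  // Assumption: an explicit numValues takes precedence over values.length.
  OptionalInt getNumValues() {
    if (numValues != null) return OptionalInt.of(numValues);
    if (values != null) return OptionalInt.of(values.length);
    return OptionalInt.empty();
  }
}
```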
    
    ### Questions for reviewers
    
    * Is "DecisionTreeClassificationModel" too long a name?
    * Is this class signature acceptable in the generated docs?
    ```
    class DecisionTreeRegressor extends TreeRegressor[DecisionTreeRegressionModel] with DecisionTreeParams[DecisionTreeRegressor] with TreeRegressorParams[DecisionTreeRegressor]
    ```
    
    ### Future
    
    We should open up the abstractions at some point.  E.g., it would be useful to be able to set tree-related parameters in 1 place and then pass those to multiple tree-based algorithms.
    
    Follow-up JIRAs will be (in this order):
    * Tree ensembles
    * Deprecate old tree code
    * Move DecisionTree implementation code to new API.
    * Move tests from the old suites which test the internals.
    * Update programming guide
    * Python API
    * Change RandomForest* to always use bootstrapping, even when numTrees = 1
    * Provide the probability of the predicted label for classification.  After we move code to the new API and update it to maintain probabilities for all labels, then we can add the probabilities to the new API.
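The last follow-up item keeps probabilities for all labels and then exposes the predicted label's probability. A minimal sketch of that step, assuming (hypothetically) that each leaf keeps per-class instance counts. `LeafProbabilities` and its methods are illustrative names, not planned Spark API.

```java
// Hypothetical sketch: per-class probabilities at a leaf, derived from instance counts.
final class LeafProbabilities {
  private LeafProbabilities() {}

  // Normalize raw class counts so they sum to 1.
  static double[] fromCounts(double[] classCounts) {
    double total = 0.0;
    for (double c : classCounts) {
      if (c < 0.0) throw new IllegalArgumentException("negative count: " + c);
      total += c;
    }
    if (total == 0.0) throw new IllegalArgumentException("leaf has no instances");
    double[] probs = new double[classCounts.length];
    for (int i = 0; i < classCounts.length; i++) {
      probs[i] = classCounts[i] / total;
    }
    return probs;
  }

  // Index of the largest probability (first one wins on ties): the predicted label.
  static int predictedLabel(double[] probs) {
    if (probs.length == 0) throw new IllegalArgumentException("empty probability vector");
    int best = 0;
    for (int i = 1; i < probs.length; i++) {
      if (probs[i] > probs[best]) best = i;
    }
    return best;
  }
}
```

With the full vector stored, the predicted label's probability is just `probs[predictedLabel(probs)]`.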
    
    CC: @mengxr  @manishamde  @codedeft  @chouqin  @MechCoder

You can merge this pull request into a Git repository by running:

    $ git pull https://github.com/jkbradley/spark dt-api-dt

Alternatively you can review and apply these changes as the patch at:

    https://github.com/apache/spark/pull/5530.patch

To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:

    This closes #5530
    
----
commit c72c1a01387bfd0e672f71f385f9a31e9d5d2c1e
Author: Joseph K. Bradley <jo...@databricks.com>
Date:   2015-04-01T03:55:22Z

    Copied changes for common items, plus DecisionTreeClassifier from original PR

commit 2532c9a9189cc01592a6f9d7d49d10055f918d7c
Author: Joseph K. Bradley <jo...@databricks.com>
Date:   2015-04-02T03:08:08Z

    partial move to spark.ml API, not done yet

commit f9fbb605f503a91f4a998cd32fee11510dfd341c
Author: Joseph K. Bradley <jo...@databricks.com>
Date:   2015-04-14T05:07:15Z

    Done with DecisionTreeClassifier, but no save/load yet.  Need to add example as well

commit 0bdc486e426abaef4c2c3619280450601a458ab8
Author: Joseph K. Bradley <jo...@databricks.com>
Date:   2015-04-14T06:28:13Z

    fixed issues after param PR was merged

commit 119f407231f52d9339dd8b821b6fe652e6b695b8
Author: Joseph K. Bradley <jo...@databricks.com>
Date:   2015-04-14T22:11:20Z

    added DecisionTreeClassifier example

commit e11673f8994314add5d2a749c1ab808f126d2bca
Author: Joseph K. Bradley <jo...@databricks.com>
Date:   2015-04-14T23:53:03Z

    Added DecisionTreeRegressor, test suites, and example

commit 7ef63ed593cbcaa87b0078b548c6c7738499d7b3
Author: Joseph K. Bradley <jo...@databricks.com>
Date:   2015-04-14T23:53:19Z

    Added DecisionTreeRegressor, test suites, and example (for real this time)

commit f8fbd24877c522138f8d16d2c1855c498b83ba0c
Author: Joseph K. Bradley <jo...@databricks.com>
Date:   2015-04-15T16:23:58Z

    imported reorg of DecisionTreeSuite from old PR.  small cleanups

----


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28487018
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala ---
    @@ -0,0 +1,300 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.impl.estimator.PredictorParams
    +import org.apache.spark.ml.param._
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
    +  Impurity => OldImpurity, Variance => OldVariance}
    +
    +
    +/**
    + * :: DeveloperApi ::
    + * Parameters for Decision Tree-based algorithms.
    + * @tparam M  Concrete class implementing this parameter trait
    + */
    +@DeveloperApi
    +private[ml] trait DecisionTreeParams[M] extends PredictorParams {
    +
    +  /**
    +   * Maximum depth of the tree.
    +   * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
    +   * (default = 5)
    +   * @group param
    +   */
    +  val maxDepth: IntParam =
    --- End diff --
    
    `final`
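The maxDepth doc in the diff above says depth 0 means one leaf and depth 1 means one internal node plus two leaves. That convention implies a tree of depth d has at most 2^d leaves and 2^(d+1) - 1 nodes. A small sketch of that arithmetic; `TreeSizeBounds` and its methods are hypothetical. With the documented default maxDepth of 5, it gives at most 32 leaves and 63 nodes.

```java
// Hypothetical helper illustrating the maxDepth convention (depth 0 = a single leaf).
final class TreeSizeBounds {
  private TreeSizeBounds() {}

  private static void check(int depth) {
    // Upper limit keeps 2^(depth + 1) within a positive long.
    if (depth < 0 || depth > 61) {
      throw new IllegalArgumentException("depth out of range: " + depth);
    }
  }

  // A full binary tree of depth d has 2^d leaves.
  static long maxLeaves(int depth) {
    check(depth);
    return 1L << depth;
  }

  // ...and 2^(d + 1) - 1 nodes in total (internal nodes plus leaves).
  static long maxNodes(int depth) {
    check(depth);
    return (1L << (depth + 1)) - 1;
  }
}
```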


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28487026
  
    --- Diff: mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java ---
    @@ -0,0 +1,97 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification;
    +
    +import java.io.File;
    +import java.io.Serializable;
    +import java.util.HashMap;
    +import java.util.Map;
    +
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.ml.impl.TreeTests;
    +import org.apache.spark.mllib.classification.LogisticRegressionSuite;
    +import org.apache.spark.mllib.regression.LabeledPoint;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.util.Utils;
    +
    +
    +public class JavaDecisionTreeClassifierSuite implements Serializable {
    +
    +  private transient JavaSparkContext sc;
    +
    +  @Before
    +  public void setUp() {
    +    sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    sc.stop();
    +    sc = null;
    +  }
    +
    +  @Test
    +  public void runDT() {
    +    int nPoints = 20;
    +    double A = 2.0;
    +    double B = -1.5;
    +
    +    JavaRDD<LabeledPoint> data = sc.parallelize(
    +        LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
    +    Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
    +    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
    +
    +    DecisionTreeClassifier dt = new DecisionTreeClassifier()
    +      .setMaxDepth(2)
    +      .setMaxBins(10)
    +      .setMinInstancesPerNode(5)
    +      .setMinInfoGain(0.0)
    +      .setMaxMemoryInMB(256)
    +      .setCacheNodeIds(false)
    +      .setCheckpointInterval(10)
    +      .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
    +    for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
    +      dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
    +    }
    +    DecisionTreeClassificationModel model = dt.fit(dataFrame);
    --- End diff --
    
    Move this line inside the for loop?
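The suggestion matters because, as written, the loop only calls setImpurity. The single fit after the loop trains with whichever impurity was set last, so the other impurities are never exercised. A self-contained sketch of the suggested shape follows. It uses hypothetical stand-in classes (`StubClassifier`, `StubModel`) in place of DecisionTreeClassifier and DecisionTreeClassificationModel so that it runs without Spark.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for an estimator/model pair; not the Spark classes.
class StubModel {
  final String impurity;

  StubModel(String impurity) {
    this.impurity = impurity;
  }
}

class StubClassifier {
  static String[] supportedImpurities() {
    return new String[] {"entropy", "gini"};
  }

  private String impurity = "gini";

  StubClassifier setImpurity(String value) {
    impurity = value;
    return this;
  }

  // Records the impurity in effect at fit time.
  StubModel fit() {
    return new StubModel(impurity);
  }
}

class FitPerImpurity {
  // Fitting inside the loop trains one model per impurity,
  // instead of a single model for the last impurity set.
  static List<StubModel> fitAll(StubClassifier dt) {
    List<StubModel> models = new ArrayList<>();
    for (String impurity : StubClassifier.supportedImpurities()) {
      dt.setImpurity(impurity);
      models.add(dt.fit());
    }
    return models;
  }
}
```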


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28532955
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.examples.ml
    +
    +import scala.collection.mutable
    +import scala.language.reflectiveCalls
    +
    +import scopt.OptionParser
    +
    +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.examples.mllib.AbstractParams
    +import org.apache.spark.ml.{Pipeline, PipelineStage}
    +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
    +import org.apache.spark.ml.impl.tree.DecisionTreeModel
    +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.StringType
    +import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.sql.functions.callUDF
    +
    +
    +/**
    + * An example runner for decision trees and random forests. Run with
    + * {{{
    + * ./bin/run-example ml.DecisionTreeExample [options]
    + * }}}
    + * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    + */
    +object DecisionTreeExample {
    +
    +  case class Params(
    +      input: String = null,
    +      testInput: String = "",
    +      dataFormat: String = "libsvm",
    +      algo: String = "Classification",
    +      maxDepth: Int = 5,
    +      maxBins: Int = 32,
    +      minInstancesPerNode: Int = 1,
    +      minInfoGain: Double = 0.0,
    +      numTrees: Int = 1,
    +      featureSubsetStrategy: String = "auto",
    +      fracTest: Double = 0.2,
    +      cacheNodeIds: Boolean = false,
    +      checkpointDir: Option[String] = None,
    +      checkpointInterval: Int = 10) extends AbstractParams[Params]
    +
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    +        .action((x, c) => c.copy(algo = x))
    +      opt[Int]("maxDepth")
    +        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
    +        .action((x, c) => c.copy(maxDepth = x))
    +      opt[Int]("maxBins")
    +        .text(s"max number of bins, default: ${defaultParams.maxBins}")
    +        .action((x, c) => c.copy(maxBins = x))
    +      opt[Int]("minInstancesPerNode")
    +        .text(s"min number of instances required at child nodes to create the parent split," +
    +          s" default: ${defaultParams.minInstancesPerNode}")
    +        .action((x, c) => c.copy(minInstancesPerNode = x))
    +      opt[Double]("minInfoGain")
    +        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
    +        .action((x, c) => c.copy(minInfoGain = x))
    +      opt[Double]("fracTest")
    +        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
    +          s"this option is ignored. default: ${defaultParams.fracTest}")
    +        .action((x, c) => c.copy(fracTest = x))
    +      opt[Boolean]("cacheNodeIds")
    +        .text(s"whether to use node Id cache during training, " +
    +          s"default: ${defaultParams.cacheNodeIds}")
    +        .action((x, c) => c.copy(cacheNodeIds = x))
    +      opt[String]("checkpointDir")
    +        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
    +         s"default: ${defaultParams.checkpointDir match {
    +           case Some(strVal) => strVal
    +           case None => "None"
    +         }}")
    +        .action((x, c) => c.copy(checkpointDir = Some(x)))
    +      opt[Int]("checkpointInterval")
    +        .text(s"how often to checkpoint the node Id cache, " +
    +         s"default: ${defaultParams.checkpointInterval}")
    +        .action((x, c) => c.copy(checkpointInterval = x))
    +      opt[String]("testInput")
    +        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
    +          s" default: ${defaultParams.testInput}")
    +        .action((x, c) => c.copy(testInput = x))
    +      opt[String]("<dataFormat>")
    +        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
    +        .action((x, c) => c.copy(dataFormat = x))
    +      arg[String]("<input>")
    +        .text("input path to labeled examples")
    +        .required()
    +        .action((x, c) => c.copy(input = x))
    +      checkConfig { params =>
    +        if (params.fracTest < 0 || params.fracTest > 1) {
    +          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
    +        } else {
    +          success
    +        }
    +      }
    +    }
    +
    +    parser.parse(args, defaultParams).map { params =>
    +      run(params)
    +    }.getOrElse {
    +      sys.exit(1)
    +    }
    +  }
    +
    +  /** Load a dataset from the given path, using the given format */
    +  private[ml] def loadData(
    +      sc: SparkContext,
    +      path: String,
    +      format: String,
    +      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
    +    format match {
    +      case "dense" => MLUtils.loadLabeledPoints(sc, path)
    +      case "libsvm" => expectedNumFeatures match {
    +        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
    +        case None => MLUtils.loadLibSVMFile(sc, path)
    +      }
    +      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
    +    }
    +  }
    +
    +  /**
    +   * Load training and test data from files.
    +   * @param input  Path to input dataset.
    +   * @param dataFormat  "libsvm" or "dense"
    +   * @param testInput  Path to test dataset.
    +   * @param algo  Classification or Regression
    +   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
    +   * @return  (training dataset, test dataset)
    +   */
    +  private[ml] def loadDatasets(
    +      sc: SparkContext,
    +      input: String,
    +      dataFormat: String,
    +      testInput: String,
    +      algo: String,
    +      fracTest: Double): (DataFrame, DataFrame) = {
    +    val sqlContext = new SQLContext(sc)
    +    import sqlContext.implicits._
    +
    +    // Load training data
    +    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
    +
    +    // Load or create test set
    +    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
    +      // Load testInput.
    +      val numFeatures = origExamples.take(1)(0).features.size
    +      val origTestExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat, Some(numFeatures))
    +      Array(origExamples, origTestExamples)
    +    } else {
    +      // Split input into training, test.
    +      origExamples.randomSplit(Array(1.0 - fracTest, fracTest))
    +    }
    +
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    +      algo.toLowerCase match {
    +        case "classification" =>
    +          val convertToString: Double => String = (label: Double) => label.toString
    +          data.select(
    +            callUDF(convertToString, StringType, data("label")).as("labelString"), data("features"))
    +        case "regression" =>
    +          data
    +        case _ =>
    +          throw new IllegalArgumentException("Algo ${params.algo} not supported.")
    +      }
    +    }
    +    val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
    +
    +    (dataframes(0), dataframes(1))
    +  }
    +
    +  def run(params: Params) {
    +    val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params")
    +    val sc = new SparkContext(conf)
    +    params.checkpointDir.foreach(sc.setCheckpointDir)
    +    val algo = params.algo.toLowerCase
    +
    +    println(s"DecisionTreeExample with parameters:\n$params")
    +
    +    // Load training and test data and cache it.
    +    val (training: DataFrame, test: DataFrame) =
    +      loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
    +
    +    val numTraining = training.count()
    +    val numTest = test.count()
    +    val numFeatures = training.select("features").first().getAs[Vector](0).size
    +    println("Loaded data:")
    +    println(s"  numTraining = $numTraining, numTest = $numTest")
    +    println(s"  numFeatures = $numFeatures")
    +
    +    // Set up Pipeline
    +    val stages = new mutable.ArrayBuffer[PipelineStage]()
    +    // (1) For classification, re-index classes.
    +    if (algo == "classification") {
    +      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol("label")
    +      stages += labelIndexer
    +    }
    +    // (2) Identify categorical features using VectorIndexer.
    +    //     Features with more than maxCategories values will be treated as continuous.
    +    val featuresIndexer = new VectorIndexer().setInputCol("features")
    --- End diff --
    
    Does "chop down" mean putting one item per line?
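The `loadDatasets` method in the example diff above holds out a fraction `fracTest` of the input when no `testInput` is given. Below is a Java analogue of that behavior for illustration, assuming independent per-element sampling. Spark's `randomSplit` operates on RDD partitions and is not reproduced here; `TrainTestSplit` is a hypothetical name.

Two details in the quoted diff may be worth a second look. First, the `testInput` branch passes `input`, not `testInput`, to `loadData`, so it appears to reload the training file. Second, the `IllegalArgumentException` message lacks the `s` interpolator, so `${params.algo}` would be printed literally.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical seeded train/test split by fraction, loosely analogous to
// origExamples.randomSplit(Array(1.0 - fracTest, fracTest)); not Spark's implementation.
final class TrainTestSplit<T> {
  final List<T> training = new ArrayList<>();
  final List<T> test = new ArrayList<>();

  TrainTestSplit(List<T> data, double fracTest, long seed) {
    // Same range check as the example's checkConfig: fracTest must be in [0, 1].
    if (fracTest < 0.0 || fracTest > 1.0) {
      throw new IllegalArgumentException("fracTest " + fracTest + " should be in [0,1].");
    }
    Random rng = new Random(seed);
    for (T x : data) {
      // Each element independently lands in the test set with probability fracTest.
      if (rng.nextDouble() < fracTest) {
        test.add(x);
      } else {
        training.add(x);
      }
    }
  }
}
```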


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476291
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---
    @@ -0,0 +1,155 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
    +import org.apache.spark.ml.impl.tree._
    +import org.apache.spark.ml.param.{Params, ParamMap}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.DataFrame
    +
    +
    +/**
    + * :: AlphaComponent ::
    + *
    + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
    + * for classification.
    + * It supports both binary and multiclass labels, as well as both continuous and categorical
    + * features.
    + */
    +@AlphaComponent
    +class DecisionTreeClassifier
    +  extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
    +  with DecisionTreeParams[DecisionTreeClassifier]
    +  with TreeClassifierParams[DecisionTreeClassifier] {
    +
    +  // Override parameter setters from parent trait for Java API compatibility.
    +
    +  override def setMaxDepth(maxDepth: Int): DecisionTreeClassifier = super.setMaxDepth(maxDepth)
    --- End diff --
    
    About the argument name: I think it is redundant to mention `maxDepth` twice. Existing implementations use `value`.
    
    Should it return `this.type` (if we return `this.type` in the parent class)?
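In Scala, returning `this.type` from a trait setter gives chained calls the concrete type. Java callers, however, may see only the erased trait type, which is presumably why the diff overrides the setters "for Java API compatibility". Java has no `this.type`. A common Java alternative is a recursive type parameter, similar in spirit to the `M` parameter on `DecisionTreeParams[M]`. The sketch below uses hypothetical class names and names the argument `value`, as suggested above.

```java
// Hypothetical sketch of the self-type pattern: setters declared once in the parent
// still return the concrete subclass type, so chained calls keep the subclass API.
abstract class TreeParamsSketch<M extends TreeParamsSketch<M>> {
  private int maxDepth = 5;

  // Subclasses return `this`; avoids an unchecked cast in the parent.
  protected abstract M self();

  M setMaxDepth(int value) {
    this.maxDepth = value;
    return self();
  }

  int getMaxDepth() {
    return maxDepth;
  }
}

final class ClassifierSketch extends TreeParamsSketch<ClassifierSketch> {
  private String impurity = "gini";

  @Override
  protected ClassifierSketch self() {
    return this;
  }

  ClassifierSketch setImpurity(String value) {
    this.impurity = value;
    return this;
  }

  String getImpurity() {
    return impurity;
  }
}
```

Because `setMaxDepth` returns `ClassifierSketch`, a Java caller can write `new ClassifierSketch().setMaxDepth(3).setImpurity("entropy")` without casts or per-class overrides.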


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28487021
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala ---
    @@ -0,0 +1,300 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.impl.estimator.PredictorParams
    +import org.apache.spark.ml.param._
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
    +  Impurity => OldImpurity, Variance => OldVariance}
    +
    +
    +/**
    + * :: DeveloperApi ::
    + * Parameters for Decision Tree-based algorithms.
    + * @tparam M  Concrete class implementing this parameter trait
    + */
    +@DeveloperApi
    +private[ml] trait DecisionTreeParams[M] extends PredictorParams {
    +
    +  /**
    +   * Maximum depth of the tree.
    +   * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
    +   * (default = 5)
    +   * @group param
    +   */
    +  val maxDepth: IntParam =
    +    new IntParam(this, "maxDepth", "Maximum depth of the tree." +
    +      " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
    +
    +  /**
    +   * Maximum number of bins used for discretizing continuous features and for choosing how to split
    +   * on features at each node.  More bins give higher granularity.
    +   * Must be >= 2 and >= number of categories in any categorical feature.
    +   * (default = 32)
    +   * @group param
    +   */
    +  val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for discretizing" +
    +    " continuous features.  Must be >=2 and >= number of categories for any categorical feature.")
    +
    +  /**
    +   * Minimum number of instances each child must have after split.
    +   * If a split cause left or right child to have less than minInstancesPerNode,
    +   * this split will not be considered as a valid split.
    +   * Should be >= 1.
    +   * (default = 1)
    +   * @group param
    +   */
    +  val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum number" +
    +    " of instances each child must have after split.  If a split cause left or right child to" +
    --- End diff --
    
    `causes`
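The scaladoc in this hunk defines three constraints on tree growth. Below is a minimal sketch of how they translate into arithmetic and checks. It is plain Scala, the object and method names are hypothetical helpers rather than Spark APIs, and it assumes the depth convention stated for maxDepth, under which a full tree of depth d has 2^d leaves.

```scala
// Sketch of the constraints documented for maxDepth, minInstancesPerNode and
// maxBins. Hypothetical helpers for illustration; not Spark's implementation.
object TreeParamSemantics {

  // Depth convention from the maxDepth doc: depth 0 = 1 leaf,
  // depth 1 = 1 internal node + 2 leaves. A full binary tree of depth d
  // therefore has 2^d leaves and 2^d - 1 internal nodes.
  def maxLeaves(maxDepth: Int): Long = {
    require(maxDepth >= 0 && maxDepth <= 62, s"maxDepth $maxDepth out of range for a Long count")
    1L << maxDepth
  }
  def maxInternalNodes(maxDepth: Int): Long = maxLeaves(maxDepth) - 1
  def maxTotalNodes(maxDepth: Int): Long = maxLeaves(maxDepth) + maxInternalNodes(maxDepth)

  // A candidate split is rejected if either child would receive fewer than
  // minInstancesPerNode training instances.
  def isValidSplit(leftCount: Long, rightCount: Long, minInstancesPerNode: Int): Boolean =
    leftCount >= minInstancesPerNode && rightCount >= minInstancesPerNode

  // maxBins must be >= 2 and >= the number of categories of every
  // categorical feature (categoryCounts holds one count per such feature).
  def isValidMaxBins(maxBins: Int, categoryCounts: Seq[Int]): Boolean =
    maxBins >= 2 && categoryCounts.forall(_ <= maxBins)
}
```

With the default maxDepth of 5, this gives at most 32 leaves and 63 nodes in total.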


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by SparkQA <gi...@git.apache.org>.
Github user SparkQA commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93478358
  
      [Test build #30352 has started](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/30352/consoleFull) for PR 5530 at commit [`f8fbd24`](https://github.com/apache/spark/commit/f8fbd24877c522138f8d16d2c1855c498b83ba0c).




[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by AmplabJenkins <gi...@git.apache.org>.
Github user AmplabJenkins commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93486052
  
    Test FAILed.
    Refer to this link for build results (access rights to CI server needed): 
    https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/30352/




[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476286
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.examples.ml
    +
    +import scala.collection.mutable
    +import scala.language.reflectiveCalls
    +
    +import scopt.OptionParser
    +
    +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.examples.mllib.AbstractParams
    +import org.apache.spark.ml.{Pipeline, PipelineStage}
    +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
    +import org.apache.spark.ml.impl.tree.DecisionTreeModel
    +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.StringType
    +import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.sql.functions.callUDF
    +
    +
    +/**
    + * An example runner for decision trees and random forests. Run with
    + * {{{
    + * ./bin/run-example ml.DecisionTreeExample [options]
    + * }}}
    + * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    + */
    +object DecisionTreeExample {
    +
    +  case class Params(
    +      input: String = null,
    +      testInput: String = "",
    +      dataFormat: String = "libsvm",
    +      algo: String = "Classification",
    +      maxDepth: Int = 5,
    +      maxBins: Int = 32,
    +      minInstancesPerNode: Int = 1,
    +      minInfoGain: Double = 0.0,
    +      numTrees: Int = 1,
    +      featureSubsetStrategy: String = "auto",
    +      fracTest: Double = 0.2,
    +      cacheNodeIds: Boolean = false,
    +      checkpointDir: Option[String] = None,
    +      checkpointInterval: Int = 10) extends AbstractParams[Params]
    +
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    +        .action((x, c) => c.copy(algo = x))
    +      opt[Int]("maxDepth")
    +        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
    +        .action((x, c) => c.copy(maxDepth = x))
    +      opt[Int]("maxBins")
    +        .text(s"max number of bins, default: ${defaultParams.maxBins}")
    +        .action((x, c) => c.copy(maxBins = x))
    +      opt[Int]("minInstancesPerNode")
    +        .text(s"min number of instances required at child nodes to create the parent split," +
    +          s" default: ${defaultParams.minInstancesPerNode}")
    +        .action((x, c) => c.copy(minInstancesPerNode = x))
    +      opt[Double]("minInfoGain")
    +        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
    +        .action((x, c) => c.copy(minInfoGain = x))
    +      opt[Double]("fracTest")
    +        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
    +          s"this option is ignored. default: ${defaultParams.fracTest}")
    +        .action((x, c) => c.copy(fracTest = x))
    +      opt[Boolean]("cacheNodeIds")
    +        .text(s"whether to use node Id cache during training, " +
    +          s"default: ${defaultParams.cacheNodeIds}")
    +        .action((x, c) => c.copy(cacheNodeIds = x))
    +      opt[String]("checkpointDir")
    +        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
    +         s"default: ${defaultParams.checkpointDir match {
    +           case Some(strVal) => strVal
    +           case None => "None"
    +         }}")
    +        .action((x, c) => c.copy(checkpointDir = Some(x)))
    +      opt[Int]("checkpointInterval")
    +        .text(s"how often to checkpoint the node Id cache, " +
    +         s"default: ${defaultParams.checkpointInterval}")
    +        .action((x, c) => c.copy(checkpointInterval = x))
    +      opt[String]("testInput")
    +        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
    +          s" default: ${defaultParams.testInput}")
    +        .action((x, c) => c.copy(testInput = x))
    +      opt[String]("dataFormat")
    +        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
    +        .action((x, c) => c.copy(dataFormat = x))
    +      arg[String]("<input>")
    +        .text("input path to labeled examples")
    +        .required()
    +        .action((x, c) => c.copy(input = x))
    +      checkConfig { params =>
    +        if (params.fracTest < 0 || params.fracTest > 1) {
    +          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
    +        } else {
    +          success
    +        }
    +      }
    +    }
    +
    +    parser.parse(args, defaultParams).map { params =>
    +      run(params)
    +    }.getOrElse {
    +      sys.exit(1)
    +    }
    +  }
    +
    +  /** Load a dataset from the given path, using the given format */
    +  private[ml] def loadData(
    +      sc: SparkContext,
    +      path: String,
    +      format: String,
    +      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
    +    format match {
    +      case "dense" => MLUtils.loadLabeledPoints(sc, path)
    +      case "libsvm" => expectedNumFeatures match {
    +        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
    +        case None => MLUtils.loadLibSVMFile(sc, path)
    +      }
    +      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
    +    }
    +  }
    +
    +  /**
    +   * Load training and test data from files.
    +   * @param input  Path to input dataset.
    +   * @param dataFormat  "libsvm" or "dense"
    +   * @param testInput  Path to test dataset.
    +   * @param algo  Classification or Regression
    +   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
    +   * @return  (training dataset, test dataset)
    +   */
    +  private[ml] def loadDatasets(
    +      sc: SparkContext,
    +      input: String,
    +      dataFormat: String,
    +      testInput: String,
    +      algo: String,
    +      fracTest: Double): (DataFrame, DataFrame) = {
    +    val sqlContext = new SQLContext(sc)
    +    import sqlContext.implicits._
    +
    +    // Load training data
    +    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
    +
    +    // Load or create test set
    +    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
    +      // Load testInput.
    +      val numFeatures = origExamples.take(1)(0).features.size
    +      val origTestExamples: RDD[LabeledPoint] = loadData(sc, testInput, dataFormat, Some(numFeatures))
    +      Array(origExamples, origTestExamples)
    +    } else {
    +      // Split input into training, test.
    +      origExamples.randomSplit(Array(1.0 - fracTest, fracTest))
    +    }
    +
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    +      algo.toLowerCase match {
    +        case "classification" =>
    +          val convertToString: Double => String = (label: Double) => label.toString
    +          data.select(
    +            callUDF(convertToString, StringType, data("label")).as("labelString"), data("features"))
    +        case "regression" =>
    +          data
    +        case _ =>
    +          throw new IllegalArgumentException(s"Algo $algo not supported.")
    +      }
    +    }
    +    val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
    +
    +    (dataframes(0), dataframes(1))
    +  }
    +
    +  def run(params: Params) {
    +    val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params")
    +    val sc = new SparkContext(conf)
    +    params.checkpointDir.foreach(sc.setCheckpointDir)
    +    val algo = params.algo.toLowerCase
    +
    +    println(s"DecisionTreeExample with parameters:\n$params")
    +
    +    // Load training and test data and cache it.
    +    val (training: DataFrame, test: DataFrame) =
    +      loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
    +
    +    val numTraining = training.count()
    +    val numTest = test.count()
    +    val numFeatures = training.select("features").first().getAs[Vector](0).size
    +    println("Loaded data:")
    +    println(s"  numTraining = $numTraining, numTest = $numTest")
    +    println(s"  numFeatures = $numFeatures")
    +
    +    // Set up Pipeline
    +    val stages = new mutable.ArrayBuffer[PipelineStage]()
    +    // (1) For classification, re-index classes.
    +    if (algo == "classification") {
    +      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol("label")
    +      stages += labelIndexer
    +    }
    +    // (2) Identify categorical features using VectorIndexer.
    +    //     Features with more than maxCategories values will be treated as continuous.
    +    val featuresIndexer = new VectorIndexer().setInputCol("features")
    +      .setOutputCol("indexedFeatures").setMaxCategories(10)
    +    stages += featuresIndexer
    +    // (3) Learn DecisionTree
    +    val dt = algo match {
    +      case "classification" =>
    +        new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
    +          .setMaxDepth(params.maxDepth)
    +          .setMaxBins(params.maxBins)
    +          .setMinInstancesPerNode(params.minInstancesPerNode)
    +          .setMinInfoGain(params.minInfoGain)
    +          .setCacheNodeIds(params.cacheNodeIds)
    +          .setCheckpointInterval(params.checkpointInterval)
    +      case "regression" =>
    +        new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
    +          .setMaxDepth(params.maxDepth)
    +          .setMaxBins(params.maxBins)
    +          .setMinInstancesPerNode(params.minInstancesPerNode)
    +          .setMinInfoGain(params.minInfoGain)
    +          .setCacheNodeIds(params.cacheNodeIds)
    +          .setCheckpointInterval(params.checkpointInterval)
    +      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
    +    }
    +    stages += dt
    +    val pipeline = new Pipeline().setStages(stages.toArray)
    +
    +    // Fit the Pipeline
    +    val startTime = System.nanoTime()
    +    val pipelineModel = pipeline.fit(training)
    +    val elapsedTime = (System.nanoTime() - startTime) / 1e9
    +    println(s"Training time: $elapsedTime seconds")
    +
    +    // Get the trained Decision Tree from the fitted PipelineModel
    +    val treeModel: DecisionTreeModel = algo match {
    +      case "classification" =>
    +        pipelineModel.getModel[DecisionTreeClassificationModel](
    --- End diff --
    
    Are the types necessary for compile?
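In the example's data flow, `loadDatasets` either loads a separate test set or holds out `fracTest` of the input, and for classification it turns labels into strings so that `StringIndexer` can re-index them. The following is a rough, Spark-free sketch of those two steps using plain Scala collections in place of RDDs and DataFrames. The names are hypothetical, the split is a simple per-element Bernoulli draw rather than Spark's `randomSplit`, and the tie-breaking rule in `fitIndex` is an assumption made for the sketch, not documented `StringIndexer` behavior.

```scala
// Spark-free sketch of the data preparation in loadDatasets plus stage (1).
// Plain Scala collections stand in for RDDs/DataFrames; all names here are
// hypothetical and this is not the PR's code.
object DataPrepSketch {

  // Returns (training, test). A separately supplied test set is used as-is
  // (like the testInput option); otherwise each example goes to the test set
  // with probability fracTest, similar in spirit to
  // randomSplit(Array(1.0 - fracTest, fracTest)).
  def trainTest[A](
      examples: Seq[A],
      testSet: Option[Seq[A]],
      fracTest: Double,
      seed: Long): (Seq[A], Seq[A]) = {
    require(fracTest >= 0.0 && fracTest <= 1.0,
      s"fracTest $fracTest value incorrect; should be in [0,1].")
    testSet match {
      case Some(test) => (examples, test)
      case None =>
        val rng = new scala.util.Random(seed)
        // Draw exactly one random number per example, in order.
        val tagged = examples.map(e => (e, rng.nextDouble() < fracTest))
        (tagged.collect { case (e, false) => e }, tagged.collect { case (e, true) => e })
    }
  }

  // Mirrors labelsToStrings: Double labels become Strings for re-indexing.
  def labelsToStrings(labels: Seq[Double]): Seq[String] = labels.map(_.toString)

  // Frequency-ordered indexing in the spirit of StringIndexer: the most
  // frequent label gets 0.0. Breaking ties by label string is a choice
  // made only for this sketch.
  def fitIndex(labels: Seq[String]): Map[String, Double] =
    labels
      .groupBy(identity)
      .toSeq
      .map { case (label, occurrences) => (label, occurrences.size) }
      .sortBy { case (label, count) => (-count, label) }
      .zipWithIndex
      .map { case ((label, _), idx) => label -> idx.toDouble }
      .toMap
}
```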




[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28533086
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    +    // Get the trained Decision Tree from the fitted PipelineModel
    +    val treeModel: DecisionTreeModel = algo match {
    +      case "classification" =>
    +        pipelineModel.getModel[DecisionTreeClassificationModel](
    --- End diff --
    
    Yep.  I think it's because "dt" can't be resolved to a concrete type since I use it for both Classifier and Regressor.
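Below is a simplified model of the typing issue described here. It uses hypothetical stand-in classes, not Spark's `PipelineStage`, `Estimator`, `Model` or `PipelineModel`. When one value may hold either of two estimators, its static type is only a common supertype, so a method whose result type is the estimator's model type `M` has nothing concrete from which to infer `M`. Matching on the concrete class is one way to recover it. It is shown only as an alternative for illustration, not as what the PR does.

```scala
// Hypothetical stand-ins for a stage / estimator / model hierarchy.
// These are NOT Spark's classes; they only reproduce the typing problem.
object TreeTypingSketch {
  trait Stage
  trait TreeModel { def kind: String }
  abstract class Model[M <: Model[M]] extends TreeModel
  abstract class Estimator[M <: Model[M]] extends Stage { def fit(): M }

  final class ClassificationModel extends Model[ClassificationModel] {
    val kind = "classification"
  }
  final class RegressionModel extends Model[RegressionModel] {
    val kind = "regression"
  }
  final class Classifier extends Estimator[ClassificationModel] {
    def fit(): ClassificationModel = new ClassificationModel
  }
  final class Regressor extends Estimator[RegressionModel] {
    def fit(): RegressionModel = new RegressionModel
  }

  // Analogue of a typed lookup: the result type is the estimator's model type M.
  def getModel[M <: Model[M]](stage: Estimator[M]): M = stage.fit()

  // Like `dt` in the example: one value that may hold either estimator, so its
  // static type is only a common supertype and says nothing about M.
  def build(algo: String): Stage = algo match {
    case "classification" => new Classifier
    case "regression"     => new Regressor
    case other            => throw new IllegalArgumentException(s"Algo $other not supported.")
  }

  // Matching on the concrete class gives each branch a concrete static type,
  // so M is inferred without explicit type arguments.
  def trainedModel(dt: Stage): TreeModel = dt match {
    case c: Classifier => getModel(c) // M = ClassificationModel
    case r: Regressor  => getModel(r) // M = RegressionModel
    case other         => throw new IllegalArgumentException(s"Unexpected stage: $other")
  }
}
```

The trade-off is a runtime type test in place of compile-time type arguments. An unexpected stage type fails loudly in the last case instead of being silently cast.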




[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476276
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    --- End diff --
    
    Should we use lowercase strings?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28626296
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    --- End diff --
    
    We already use lowercase strings in many places, such as Statistics and the old tree APIs, so it would be nice to be consistent. We can change this in a follow-up PR.
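A minimal sketch of the lowercase convention discussed above. This is not code from the PR; `AlgoSketch` and `parseAlgo` are hypothetical names. The idea is to normalize the `algo` string once, using `toLowerCase(Locale.ROOT)` to avoid locale-dependent casing, and then match only on lowercase names. Capitalized input such as `Classification` keeps working, while defaults and help text can use the lowercase spelling.

```scala
import java.util.Locale

// Hypothetical helper (not part of the PR): accept the algorithm name in any
// case, normalize it to lowercase, and only then match on it.
object AlgoSketch {
  sealed trait Algo
  case object Classification extends Algo
  case object Regression extends Algo

  def parseAlgo(name: String): Either[String, Algo] =
    name.trim.toLowerCase(Locale.ROOT) match {
      case "classification" => Right(Classification)
      case "regression"     => Right(Regression)
      case other            => Left(s"Unknown algo: $other (expected classification or regression)")
    }
}
```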


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-94539212
  
    @jkbradley This was merged. Could you close this PR?


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476278
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    +        .action((x, c) => c.copy(algo = x))
    +      opt[Int]("maxDepth")
    +        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
    +        .action((x, c) => c.copy(maxDepth = x))
    +      opt[Int]("maxBins")
    +        .text(s"max number of bins, default: ${defaultParams.maxBins}")
    +        .action((x, c) => c.copy(maxBins = x))
    +      opt[Int]("minInstancesPerNode")
    +        .text(s"min number of instances required at child nodes to create the parent split," +
    +          s" default: ${defaultParams.minInstancesPerNode}")
    +        .action((x, c) => c.copy(minInstancesPerNode = x))
    +      opt[Double]("minInfoGain")
    +        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
    +        .action((x, c) => c.copy(minInfoGain = x))
    +      opt[Double]("fracTest")
    +        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
    +          s"this option is ignored. default: ${defaultParams.fracTest}")
    +        .action((x, c) => c.copy(fracTest = x))
    +      opt[Boolean]("cacheNodeIds")
    +        .text(s"whether to use node Id cache during training, " +
    +          s"default: ${defaultParams.cacheNodeIds}")
    +        .action((x, c) => c.copy(cacheNodeIds = x))
    +      opt[String]("checkpointDir")
    +        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
    +         s"default: ${defaultParams.checkpointDir match {
    +           case Some(strVal) => strVal
    +           case None => "None"
    +         }}")
    +        .action((x, c) => c.copy(checkpointDir = Some(x)))
    +      opt[Int]("checkpointInterval")
    +        .text(s"how often to checkpoint the node Id cache, " +
    +         s"default: ${defaultParams.checkpointInterval}")
    +        .action((x, c) => c.copy(checkpointInterval = x))
    +      opt[String]("testInput")
    +        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
    +          s" default: ${defaultParams.testInput}")
    +        .action((x, c) => c.copy(testInput = x))
    +      opt[String]("dataFormat")
    +        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
    +        .action((x, c) => c.copy(dataFormat = x))
    +      arg[String]("<input>")
    +        .text("input path to labeled examples")
    +        .required()
    +        .action((x, c) => c.copy(input = x))
    +      checkConfig { params =>
    +        if (params.fracTest < 0 || params.fracTest > 1) {
    +          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
    +        } else {
    +          success
    +        }
    +      }
    +    }
    +
    +    parser.parse(args, defaultParams).map { params =>
    +      run(params)
    +    }.getOrElse {
    +      sys.exit(1)
    +    }
    +  }
    +
    +  /** Load a dataset from the given path, using the given format */
    +  private[ml] def loadData(
    +      sc: SparkContext,
    +      path: String,
    +      format: String,
    +      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
    +    format match {
    +      case "dense" => MLUtils.loadLabeledPoints(sc, path)
    +      case "libsvm" => expectedNumFeatures match {
    +        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
    +        case None => MLUtils.loadLibSVMFile(sc, path)
    +      }
    +      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
    +    }
    +  }
    +
    +  /**
    +   * Load training and test data from files.
    +   * @param input  Path to input dataset.
    +   * @param dataFormat  "libsvm" or "dense"
    +   * @param testInput  Path to test dataset.
    +   * @param algo  Classification or Regression
    +   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
    +   * @return  (training dataset, test dataset)
    +   */
    +  private[ml] def loadDatasets(
    +      sc: SparkContext,
    +      input: String,
    +      dataFormat: String,
    +      testInput: String,
    +      algo: String,
    +      fracTest: Double): (DataFrame, DataFrame) = {
    +    val sqlContext = new SQLContext(sc)
    +    import sqlContext.implicits._
    +
    +    // Load training data
    +    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
    +
    +    // Load or create test set
    +    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
    +      // Load testInput.
    +      val numFeatures = origExamples.take(1)(0).features.size
    +      val origTestExamples: RDD[LabeledPoint] =
    +        loadData(sc, testInput, dataFormat, Some(numFeatures))
    +      Array(origExamples, origTestExamples)
    +    } else {
    +      // Split input into training, test.
    +      origExamples.randomSplit(Array(1.0 - fracTest, fracTest))
    +    }
    +
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    --- End diff --
    
    Do you think we should modify `StringIndexer` to convert the input column to string type?
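To make the question concrete, here is a plain-Scala sketch with no Spark dependency of the two steps the example currently performs. First it turns numeric labels into strings, and then it assigns indices by descending label frequency. The second step matches how `StringIndexer` orders labels: the most frequent label gets index 0. If `StringIndexer` cast its input column to string itself, the first step would no longer be needed. `LabelIndexSketch` and `indexLabels` are hypothetical names. Breaking ties by the label string is an assumption made here only for deterministic output, not documented `StringIndexer` behavior.

```scala
// Hypothetical sketch (not Spark code): index numeric labels the way the
// example does, by converting them to strings and ordering by frequency.
object LabelIndexSketch {
  /** Returns (label string -> index, indexed labels in input order).
    * Indices follow descending frequency; ties are broken by the string. */
  def indexLabels(labels: Seq[Double]): (Map[String, Int], Seq[Int]) = {
    // Step 1: the conversion the example performs with a UDF (1.0 -> "1.0").
    val asStrings: Seq[String] = labels.map(_.toString)
    // Step 2: most frequent label first, so it receives index 0.
    val ordered: Seq[String] = asStrings
      .groupBy(identity)
      .toSeq
      .sortBy { case (label, occurrences) => (-occurrences.size, label) }
      .map(_._1)
    val index: Map[String, Int] = ordered.zipWithIndex.toMap
    (index, asStrings.map(index))
  }
}
```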


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28532900
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    +      algo.toLowerCase match {
    +        case "classification" =>
    +          val convertToString: Double => String = (label: Double) => label.toString
    +          data.select(
    +            callUDF(convertToString, StringType, data("label")).as("labelString"), data("features"))
    --- End diff --
    
    Where is "Cast" defined?  (IntelliJ brings up Catalyst, which I don't want to expose here.)


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476293
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---
    @@ -0,0 +1,155 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
    +import org.apache.spark.ml.impl.tree._
    +import org.apache.spark.ml.param.{Params, ParamMap}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.DataFrame
    +
    +
    +/**
    + * :: AlphaComponent ::
    + *
    + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
    + * for classification.
    + * It supports both binary and multiclass labels, as well as both continuous and categorical
    + * features.
    + */
    +@AlphaComponent
    +class DecisionTreeClassifier
    +  extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
    +  with DecisionTreeParams[DecisionTreeClassifier]
    +  with TreeClassifierParams[DecisionTreeClassifier] {
    +
    +  // Override parameter setters from parent trait for Java API compatibility.
    +
    +  override def setMaxDepth(maxDepth: Int): DecisionTreeClassifier = super.setMaxDepth(maxDepth)
    +
    +  override def setMaxBins(maxBins: Int): DecisionTreeClassifier = super.setMaxBins(maxBins)
    +
    +  override def setMinInstancesPerNode(minInstancesPerNode: Int): DecisionTreeClassifier =
    +    super.setMinInstancesPerNode(minInstancesPerNode)
    +
    +  override def setMinInfoGain(minInfoGain: Double): DecisionTreeClassifier =
    +    super.setMinInfoGain(minInfoGain)
    +
    +  override def setMaxMemoryInMB(maxMemoryInMB: Int): DecisionTreeClassifier =
    +    super.setMaxMemoryInMB(maxMemoryInMB)
    +
    +  override def setCacheNodeIds(cacheNodeIds: Boolean): DecisionTreeClassifier =
    +    super.setCacheNodeIds(cacheNodeIds)
    +
    +  override def setCheckpointInterval(checkpointInterval: Int): DecisionTreeClassifier =
    +    super.setCheckpointInterval(checkpointInterval)
    +
    +  override def setImpurity(impurity: String): DecisionTreeClassifier = super.setImpurity(impurity)
    +
    +  override protected def train(
    +      dataset: DataFrame,
    +      paramMap: ParamMap): DecisionTreeClassificationModel = {
    +    val categoricalFeatures: Map[Int, Int] =
    +      MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
    +    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
    +      case Some(n: Int) => n
    +      case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
    +        s" with invalid label column, without the number of classes specified.")
    +        // TODO: Automatically index labels.
    +    }
    +    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
    +    val strategy = getOldStrategy(categoricalFeatures, numClasses)
    +    val oldModel = OldDecisionTree.train(oldDataset, strategy)
    +    DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap)
    +  }
    +
    +  /** Create a Strategy instance to use with the old API. */
    +  override private[ml] def getOldStrategy(
    +      categoricalFeatures: Map[Int, Int],
    +      numClasses: Int): OldStrategy = {
    +    val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
    +    strategy.algo = OldAlgo.Classification
    +    strategy.setImpurity(getOldImpurity)
    +    strategy
    +  }
    +}
    +
    +object DecisionTreeClassifier {
    +  /** Accessor for supported impurities */
    +  final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
    +}
    +
    +/**
    + * :: AlphaComponent ::
    + *
    + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification.
    + * It supports both binary and multiclass labels, as well as both continuous and categorical
    + * features.
    + */
    +@AlphaComponent
    +class DecisionTreeClassificationModel private[ml] (
    +    override val parent: DecisionTreeClassifier,
    +    override val fittingParamMap: ParamMap,
    +    override val rootNode: Node)
    +  extends PredictionModel[Vector, DecisionTreeClassificationModel]
    +  with DecisionTreeModel with Serializable {
    +
    +  require(rootNode != null,
    +    "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
    +
    +  override protected def predict(features: Vector): Double = {
    --- End diff --
    
    Let's mark `DecisionTreeClassificationModel` final. Otherwise, if we make `predict` protected now, changing it to public later would be a breaking change.
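    Why `final` helps here: Scala does not let a subclass override a member with weaker access. A user subclass that overrides a `protected def predict` would therefore stop compiling once the parent's `predict` became public. Declaring the model class `final` rules out such subclasses. The plain-Scala sketch below uses hypothetical names (`SketchTreeModel`, `transformOne`) and a one-feature threshold rule. It is not the Spark class, only the shape of the suggestion.

    ```scala
    import java.lang.reflect.Modifier

    // Hypothetical stand-in for a tree model; not the Spark class.
    final class SketchTreeModel(threshold: Double) {

      // Protected for now. Because the class is final, no subclass can
      // override this, so widening it to public later breaks no user code.
      protected def predict(feature: Double): Double =
        if (feature <= threshold) 0.0 else 1.0

      // Public entry point that delegates to the protected method.
      def transformOne(feature: Double): Double = predict(feature)
    }

    object FinalModelDemo {
      def main(args: Array[String]): Unit = {
        val model = new SketchTreeModel(threshold = 0.5)
        println(model.transformOne(0.2)) // 0.0
        println(model.transformOne(0.9)) // 1.0
        // Scala's `final class` is compiled to a JVM class with the final flag.
        println(Modifier.isFinal(classOf[SketchTreeModel].getModifiers)) // true
      }
    }
    ```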


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28487006
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Split.scala ---
    @@ -0,0 +1,105 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
    +import org.apache.spark.mllib.tree.model.{Split => OldSplit}
    +
    +
    +/**
    + * Interface for a "Split," which specifies a test made at a decision tree node
    + * to choose the left or right path.
    + */
    +trait Split extends Serializable {
    +
    +  /** Index of feature which this split tests */
    +  def feature: Int
    +
    +  /** Return true (split to left) or false (split to right) */
    +  private[ml] def goLeft(features: Vector): Boolean
    +
    +  /** Convert to old Split format */
    +  private[tree] def toOld: OldSplit
    +}
    +
    +private[ml] object Split {
    +
    +  def fromOld(oldSplit: OldSplit): Split = {
    +    oldSplit.featureType match {
    +      case OldFeatureType.Categorical =>
    +        new CategoricalSplit(feature = oldSplit.feature,
    +          categories = oldSplit.categories.toSet)
    +      case OldFeatureType.Continuous =>
    +        new ContinuousSplit(feature = oldSplit.feature, threshold = oldSplit.threshold)
    +    }
    +  }
    +
    +}
    +
    +/**
    + * Split which tests a categorical feature.
    + * @param feature  Index of the feature to test
    + * @param categories  If the feature value is in this set of categories, then the split goes left.
    --- End diff --
    
    `categories` -> `leftCategories`?
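    The rename makes the field name say which child the listed values go to. The sketch below is a hypothetical plain-Scala class (`CategoricalSplitSketch`). It takes features as an `Array[Double]` instead of Spark's `Vector` so that it runs without Spark on the classpath.

    ```scala
    // Hypothetical sketch; not the Spark CategoricalSplit.
    // A feature value in leftCategories sends the example to the left child.
    class CategoricalSplitSketch(val feature: Int, val leftCategories: Set[Double]) {

      /** True (go left) iff the tested feature value is one of leftCategories. */
      def goLeft(features: Array[Double]): Boolean =
        leftCategories.contains(features(feature))
    }

    object LeftCategoriesDemo {
      def main(args: Array[String]): Unit = {
        val split = new CategoricalSplitSketch(feature = 1, leftCategories = Set(0.0, 2.0))
        println(split.goLeft(Array(5.0, 2.0))) // true: feature 1 is 2.0
        println(split.goLeft(Array(5.0, 1.0))) // false: feature 1 is 1.0
      }
    }
    ```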



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476277
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.examples.ml
    +
    +import scala.collection.mutable
    +import scala.language.reflectiveCalls
    +
    +import scopt.OptionParser
    +
    +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.examples.mllib.AbstractParams
    +import org.apache.spark.ml.{Pipeline, PipelineStage}
    +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
    +import org.apache.spark.ml.impl.tree.DecisionTreeModel
    +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.StringType
    +import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.sql.functions.callUDF
    +
    +
    +/**
    + * An example runner for decision trees and random forests. Run with
    + * {{{
    + * ./bin/run-example ml.DecisionTreeExample [options]
    + * }}}
    + * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    + */
    +object DecisionTreeExample {
    +
    +  case class Params(
    +      input: String = null,
    +      testInput: String = "",
    +      dataFormat: String = "libsvm",
    +      algo: String = "Classification",
    +      maxDepth: Int = 5,
    +      maxBins: Int = 32,
    +      minInstancesPerNode: Int = 1,
    +      minInfoGain: Double = 0.0,
    +      numTrees: Int = 1,
    +      featureSubsetStrategy: String = "auto",
    +      fracTest: Double = 0.2,
    +      cacheNodeIds: Boolean = false,
    +      checkpointDir: Option[String] = None,
    +      checkpointInterval: Int = 10) extends AbstractParams[Params]
    +
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    +        .action((x, c) => c.copy(algo = x))
    +      opt[Int]("maxDepth")
    +        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
    +        .action((x, c) => c.copy(maxDepth = x))
    +      opt[Int]("maxBins")
    +        .text(s"max number of bins, default: ${defaultParams.maxBins}")
    +        .action((x, c) => c.copy(maxBins = x))
    +      opt[Int]("minInstancesPerNode")
    +        .text(s"min number of instances required at child nodes to create the parent split," +
    +          s" default: ${defaultParams.minInstancesPerNode}")
    +        .action((x, c) => c.copy(minInstancesPerNode = x))
    +      opt[Double]("minInfoGain")
    +        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
    +        .action((x, c) => c.copy(minInfoGain = x))
    +      opt[Double]("fracTest")
    +        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
    +          s"this option is ignored. default: ${defaultParams.fracTest}")
    +        .action((x, c) => c.copy(fracTest = x))
    +      opt[Boolean]("cacheNodeIds")
    +        .text(s"whether to use node Id cache during training, " +
    +          s"default: ${defaultParams.cacheNodeIds}")
    +        .action((x, c) => c.copy(cacheNodeIds = x))
    +      opt[String]("checkpointDir")
    +        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
    +         s"default: ${defaultParams.checkpointDir match {
    +           case Some(strVal) => strVal
    +           case None => "None"
    +         }}")
    +        .action((x, c) => c.copy(checkpointDir = Some(x)))
    +      opt[Int]("checkpointInterval")
    +        .text(s"how often to checkpoint the node Id cache, " +
    +         s"default: ${defaultParams.checkpointInterval}")
    +        .action((x, c) => c.copy(checkpointInterval = x))
    +      opt[String]("testInput")
    +        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
    +          s" default: ${defaultParams.testInput}")
    +        .action((x, c) => c.copy(testInput = x))
    +      opt[String]("dataFormat")
    +        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
    +        .action((x, c) => c.copy(dataFormat = x))
    +      arg[String]("<input>")
    +        .text("input path to labeled examples")
    +        .required()
    +        .action((x, c) => c.copy(input = x))
    +      checkConfig { params =>
    +        if (params.fracTest < 0 || params.fracTest > 1) {
    +          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
    +        } else {
    +          success
    +        }
    +      }
    +    }
    +
    +    parser.parse(args, defaultParams).map { params =>
    +      run(params)
    +    }.getOrElse {
    +      sys.exit(1)
    +    }
    +  }
    +
    +  /** Load a dataset from the given path, using the given format */
    +  private[ml] def loadData(
    +      sc: SparkContext,
    +      path: String,
    +      format: String,
    +      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
    +    format match {
    +      case "dense" => MLUtils.loadLabeledPoints(sc, path)
    +      case "libsvm" => expectedNumFeatures match {
    +        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
    +        case None => MLUtils.loadLibSVMFile(sc, path)
    +      }
    +      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
    +    }
    +  }
    +
    +  /**
    +   * Load training and test data from files.
    +   * @param input  Path to input dataset.
    +   * @param dataFormat  "libsvm" or "dense"
    +   * @param testInput  Path to test dataset.
    +   * @param algo  Classification or Regression
    +   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
    +   * @return  (training dataset, test dataset)
    +   */
    +  private[ml] def loadDatasets(
    +      sc: SparkContext,
    +      input: String,
    +      dataFormat: String,
    +      testInput: String,
    +      algo: String,
    +      fracTest: Double): (DataFrame, DataFrame) = {
    +    val sqlContext = new SQLContext(sc)
    +    import sqlContext.implicits._
    +
    +    // Load training data
    +    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
    +
    +    // Load or create test set
    +    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
    +      // Load testInput.
    +      val numFeatures = origExamples.take(1)(0).features.size
    +      val origTestExamples: RDD[LabeledPoint] =
    +        loadData(sc, testInput, dataFormat, Some(numFeatures))
    +      Array(origExamples, origTestExamples)
    +    } else {
    +      // Split input into training, test.
    +      origExamples.randomSplit(Array(1.0 - fracTest, fracTest))
    --- End diff --
    
    Set a fixed seed to make the test reproducible.
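    In Spark, `RDD.randomSplit` accepts an optional `seed: Long` after the weights. The line above could therefore become, for example, `origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345)`, where the seed value is arbitrary. The plain-Scala sketch below is not Spark's implementation, which samples each partition separately. It only demonstrates the property being requested: the same data, fraction, and seed always yield the same train/test split.

    ```scala
    import scala.util.Random

    // Plain-Scala sketch of a reproducible train/test split; not Spark's
    // RDD.randomSplit, which samples each partition separately.
    object SeededSplitSketch {

      /** Sends each element to the test set with probability fracTest. */
      def split[T](data: Seq[T], fracTest: Double, seed: Long): (Seq[T], Seq[T]) = {
        val rng = new Random(seed)
        // One uniform draw per element, taken in order, so the result depends
        // only on the data order, fracTest, and the seed.
        val draws = Seq.fill(data.size)(rng.nextDouble())
        val paired = data.zip(draws)
        val train = paired.collect { case (x, r) if r >= fracTest => x }
        val test = paired.collect { case (x, r) if r < fracTest => x }
        (train, test)
      }

      def main(args: Array[String]): Unit = {
        val data = 1 to 20
        val first = split(data, fracTest = 0.2, seed = 42L)
        val second = split(data, fracTest = 0.2, seed = 42L)
        println(first == second) // true: same seed, same split
      }
    }
    ```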



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476290
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---
    @@ -0,0 +1,155 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
    +import org.apache.spark.ml.impl.tree._
    +import org.apache.spark.ml.param.{Params, ParamMap}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.DataFrame
    +
    +
    +/**
    + * :: AlphaComponent ::
    + *
    + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
    + * for classification.
    + * It supports both binary and multiclass labels, as well as both continuous and categorical
    + * features.
    + */
    +@AlphaComponent
    +class DecisionTreeClassifier
    --- End diff --
    
    `final`?



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by SparkQA <gi...@git.apache.org>.
Github user SparkQA commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93589748
  
      [Test build #681 has started](https://amplab.cs.berkeley.edu/jenkins/job/NewSparkPullRequestBuilder/681/consoleFull) for   PR 5530 at commit [`5626c81`](https://github.com/apache/spark/commit/5626c81cc896a04fc109d09ee145487c116ecf6d).



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28487011
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Split.scala ---
    @@ -0,0 +1,105 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
    +import org.apache.spark.mllib.tree.model.{Split => OldSplit}
    +
    +
    +/**
    + * Interface for a "Split," which specifies a test made at a decision tree node
    + * to choose the left or right path.
    + */
    +trait Split extends Serializable {
    +
    +  /** Index of feature which this split tests */
    +  def feature: Int
    +
    +  /** Return true (split to left) or false (split to right) */
    +  private[ml] def goLeft(features: Vector): Boolean
    +
    +  /** Convert to old Split format */
    +  private[tree] def toOld: OldSplit
    +}
    +
    +private[ml] object Split {
    +
    +  def fromOld(oldSplit: OldSplit): Split = {
    +    oldSplit.featureType match {
    +      case OldFeatureType.Categorical =>
    +        new CategoricalSplit(feature = oldSplit.feature,
    +          categories = oldSplit.categories.toSet)
    +      case OldFeatureType.Continuous =>
    +        new ContinuousSplit(feature = oldSplit.feature, threshold = oldSplit.threshold)
    +    }
    +  }
    +
    +}
    +
    +/**
    + * Split which tests a categorical feature.
    + * @param feature  Index of the feature to test
    + * @param categories  If the feature value is in this set of categories, then the split goes left.
    + *                    Otherwise, it goes right.
    + */
    +class CategoricalSplit(override val feature: Int, val categories: Set[Double]) extends Split {
    --- End diff --
    
    `Set` is not Java-friendly. We can use `Array[Double]` in the constructor and store a `Set[Double]` internally. Btw, do we want to optimize storage by keeping whichever side has fewer categories? If so, we can have `categories: Array[Double], isLeft: Boolean`.
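    The sketch below is one possible shape for both suggestions, written in plain Scala. `CompactCategoricalSplit` and `fromLeftCategories` are hypothetical names, not Spark API. The constructor takes an `Array[Double]`, which Java callers can pass directly, and keeps a `Set[Double]` internally for membership tests. `isLeft` records whether the stored categories go left or right. The sketch assumes that category values are indexed `0.0` until `numCategories`, and it takes features as an `Array[Double]` instead of Spark's `Vector`.

    ```scala
    // Hypothetical sketch of the review suggestion; not the merged Spark class.
    class CompactCategoricalSplit(
        val feature: Int,
        storedCategories: Array[Double],
        val isLeft: Boolean) {

      // Internal Set for fast membership tests; callers pass a plain array.
      private val categories: Set[Double] = storedCategories.toSet

      /** True if the feature value should go to the left child. */
      def goLeft(features: Array[Double]): Boolean =
        if (isLeft) categories.contains(features(feature))
        else !categories.contains(features(feature))

      /** Number of categories actually stored. */
      def numStored: Int = categories.size
    }

    object CompactCategoricalSplit {
      /**
       * Builds a split from the left-going categories and stores whichever side
       * has fewer categories. Assumes categories are 0.0 until numCategories.
       */
      def fromLeftCategories(
          feature: Int,
          leftCategories: Array[Double],
          numCategories: Int): CompactCategoricalSplit = {
        val left = leftCategories.toSet
        if (left.size <= numCategories - left.size) {
          new CompactCategoricalSplit(feature, left.toArray, isLeft = true)
        } else {
          // Store the complement: the categories that go right.
          val right = (0 until numCategories).map(_.toDouble).filterNot(left.contains)
          new CompactCategoricalSplit(feature, right.toArray, isLeft = false)
        }
      }
    }

    object CompactSplitDemo {
      def main(args: Array[String]): Unit = {
        // 10 categories and 8 go left, so only the 2 right-going ones are stored.
        val split = CompactCategoricalSplit.fromLeftCategories(
          feature = 0,
          leftCategories = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0),
          numCategories = 10)
        println(split.isLeft)             // false
        println(split.numStored)          // 2
        println(split.goLeft(Array(3.0))) // true
        println(split.goLeft(Array(9.0))) // false
      }
    }
    ```

    Storing the smaller side caps the stored set at half the feature's arity. The trade-off is that the arity must be known when the split is built.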



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28626395
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.examples.ml
    +
    +import scala.collection.mutable
    +import scala.language.reflectiveCalls
    +
    +import scopt.OptionParser
    +
    +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.examples.mllib.AbstractParams
    +import org.apache.spark.ml.{Pipeline, PipelineStage}
    +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
    +import org.apache.spark.ml.impl.tree.DecisionTreeModel
    +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.StringType
    +import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.sql.functions.callUDF
    +
    +
    +/**
    + * An example runner for decision trees and random forests. Run with
    + * {{{
    + * ./bin/run-example ml.DecisionTreeExample [options]
    + * }}}
    + * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    + */
    +object DecisionTreeExample {
    +
    +  case class Params(
    +      input: String = null,
    +      testInput: String = "",
    +      dataFormat: String = "libsvm",
    +      algo: String = "Classification",
    +      maxDepth: Int = 5,
    +      maxBins: Int = 32,
    +      minInstancesPerNode: Int = 1,
    +      minInfoGain: Double = 0.0,
    +      numTrees: Int = 1,
    +      featureSubsetStrategy: String = "auto",
    +      fracTest: Double = 0.2,
    +      cacheNodeIds: Boolean = false,
    +      checkpointDir: Option[String] = None,
    +      checkpointInterval: Int = 10) extends AbstractParams[Params]
    +
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    +        .action((x, c) => c.copy(algo = x))
    +      opt[Int]("maxDepth")
    +        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
    +        .action((x, c) => c.copy(maxDepth = x))
    +      opt[Int]("maxBins")
    +        .text(s"max number of bins, default: ${defaultParams.maxBins}")
    +        .action((x, c) => c.copy(maxBins = x))
    +      opt[Int]("minInstancesPerNode")
    +        .text(s"min number of instances required at child nodes to create the parent split," +
    +          s" default: ${defaultParams.minInstancesPerNode}")
    +        .action((x, c) => c.copy(minInstancesPerNode = x))
    +      opt[Double]("minInfoGain")
    +        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
    +        .action((x, c) => c.copy(minInfoGain = x))
    +      opt[Double]("fracTest")
    +        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
    +          s"this option is ignored. default: ${defaultParams.fracTest}")
    +        .action((x, c) => c.copy(fracTest = x))
    +      opt[Boolean]("cacheNodeIds")
    +        .text(s"whether to use node Id cache during training, " +
    +          s"default: ${defaultParams.cacheNodeIds}")
    +        .action((x, c) => c.copy(cacheNodeIds = x))
    +      opt[String]("checkpointDir")
    +        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
    +         s"default: ${defaultParams.checkpointDir match {
    +           case Some(strVal) => strVal
    +           case None => "None"
    +         }}")
    +        .action((x, c) => c.copy(checkpointDir = Some(x)))
    +      opt[Int]("checkpointInterval")
    +        .text(s"how often to checkpoint the node Id cache, " +
    +         s"default: ${defaultParams.checkpointInterval}")
    +        .action((x, c) => c.copy(checkpointInterval = x))
    +      opt[String]("testInput")
    +        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
    +          s" default: ${defaultParams.testInput}")
    +        .action((x, c) => c.copy(testInput = x))
    +      opt[String]("dataFormat")
    +        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
    +        .action((x, c) => c.copy(dataFormat = x))
    +      arg[String]("<input>")
    +        .text("input path to labeled examples")
    +        .required()
    +        .action((x, c) => c.copy(input = x))
    +      checkConfig { params =>
    +        if (params.fracTest < 0 || params.fracTest > 1) {
    +          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
    +        } else {
    +          success
    +        }
    +      }
    +    }
    +
    +    parser.parse(args, defaultParams).map { params =>
    +      run(params)
    +    }.getOrElse {
    +      sys.exit(1)
    +    }
    +  }
    +
    +  /** Load a dataset from the given path, using the given format */
    +  private[ml] def loadData(
    +      sc: SparkContext,
    +      path: String,
    +      format: String,
    +      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
    +    format match {
    +      case "dense" => MLUtils.loadLabeledPoints(sc, path)
    +      case "libsvm" => expectedNumFeatures match {
    +        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
    +        case None => MLUtils.loadLibSVMFile(sc, path)
    +      }
    +      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
    +    }
    +  }
    +
    +  /**
    +   * Load training and test data from files.
    +   * @param input  Path to input dataset.
    +   * @param dataFormat  "libsvm" or "dense"
    +   * @param testInput  Path to test dataset.
    +   * @param algo  Classification or Regression
    +   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
    +   * @return  (training dataset, test dataset)
    +   */
    +  private[ml] def loadDatasets(
    +      sc: SparkContext,
    +      input: String,
    +      dataFormat: String,
    +      testInput: String,
    +      algo: String,
    +      fracTest: Double): (DataFrame, DataFrame) = {
    +    val sqlContext = new SQLContext(sc)
    +    import sqlContext.implicits._
    +
    +    // Load training data
    +    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
    +
    +    // Load or create test set
    +    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
    +      // Load testInput.
    +      val numFeatures = origExamples.take(1)(0).features.size
    +      val origTestExamples: RDD[LabeledPoint] =
    +        loadData(sc, testInput, dataFormat, Some(numFeatures))
    +      Array(origExamples, origTestExamples)
    +    } else {
    +      // Split input into training, test.
    +      origExamples.randomSplit(Array(1.0 - fracTest, fracTest))
    +    }
    +
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    +      algo.toLowerCase match {
    +        case "classification" =>
    +          val convertToString: Double => String = (label: Double) => label.toString
    +          data.select(
    +            callUDF(convertToString, StringType, data("label")).as("labelString"), data("features"))
    +        case "regression" =>
    +          data
    +        case _ =>
    +          throw new IllegalArgumentException(s"Algo $algo not supported.")
    +      }
    +    }
    +    val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
    +
    +    (dataframes(0), dataframes(1))
    +  }
    +
    +  def run(params: Params) {
    +    val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params")
    +    val sc = new SparkContext(conf)
    +    params.checkpointDir.foreach(sc.setCheckpointDir)
    +    val algo = params.algo.toLowerCase
    +
    +    println(s"DecisionTreeExample with parameters:\n$params")
    +
    +    // Load training and test data and cache it.
    +    val (training: DataFrame, test: DataFrame) =
    +      loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
    +
    +    val numTraining = training.count()
    +    val numTest = test.count()
    +    val numFeatures = training.select("features").first().getAs[Vector](0).size
    +    println("Loaded data:")
    +    println(s"  numTraining = $numTraining, numTest = $numTest")
    +    println(s"  numFeatures = $numFeatures")
    +
    +    // Set up Pipeline
    +    val stages = new mutable.ArrayBuffer[PipelineStage]()
    +    // (1) For classification, re-index classes.
    +    if (algo == "classification") {
    +      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol("label")
    +      stages += labelIndexer
    +    }
    +    // (2) Identify categorical features using VectorIndexer.
    +    //     Features with more than maxCategories values will be treated as continuous.
    +    val featuresIndexer = new VectorIndexer().setInputCol("features")
    --- End diff --
    
    Yes.
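
For classification, the example converts numeric labels to strings so that `StringIndexer` can re-index them into the `label` column. The Spark-free sketch below illustrates the idea under two assumptions: indices are assigned by descending label frequency (so the most frequent label gets 0.0), and ties are broken lexicographically. `LabelIndexSketch` is a hypothetical name, not part of Spark.

```scala
// Hypothetical sketch of frequency-based label indexing (not Spark's StringIndexer code).
object LabelIndexSketch {
  // Mirrors the example's labelsToStrings step: Double label -> String.
  def labelsToStrings(labels: Seq[Double]): Seq[String] = labels.map(_.toString)

  // Assigns each distinct label an index: most frequent first; ties are broken
  // lexicographically (an assumption made here for determinism).
  def fit(labels: Seq[String]): Map[String, Double] = {
    val counts: Map[String, Int] = labels.groupBy(identity).map { case (k, v) => (k, v.size) }
    val ordered = counts.toSeq.sortBy { case (label, count) => (-count, label) }.map(_._1)
    ordered.zipWithIndex.map { case (label, i) => (label, i.toDouble) }.toMap
  }
}
```

For the raw labels `1.0, 0.0, 1.0, 2.0, 1.0, 0.0`, label "1.0" (3 occurrences) maps to 0.0, "0.0" (2) maps to 1.0, and "2.0" (1) maps to 2.0.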

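The `VectorIndexer` comment in the diff says that features with more than `maxCategories` distinct values are treated as continuous. A minimal sketch of that rule, assuming plain `Array[Double]` rows and mapping each categorical feature's values to indices in ascending value order (an assumption; `CategoricalDetectionSketch` is hypothetical):

```scala
// Hypothetical sketch of the maxCategories rule (not Spark's VectorIndexer implementation).
object CategoricalDetectionSketch {
  /** For each feature column, returns Some(valueToIndex) if the column has at most
    * maxCategories distinct values (treated as categorical), or None (continuous). */
  def detect(rows: Seq[Array[Double]], maxCategories: Int): IndexedSeq[Option[Map[Double, Int]]] = {
    require(rows.nonEmpty, "need at least one row")
    val numFeatures = rows.head.length
    require(rows.forall(_.length == numFeatures), "all rows must have the same length")
    (0 until numFeatures).map { j =>
      val distinct = rows.map(_(j)).distinct.sorted
      if (distinct.size <= maxCategories) Some(distinct.zipWithIndex.toMap)
      else None
    }
  }
}
```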

---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28626572
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Split.scala ---
    @@ -0,0 +1,105 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
    +import org.apache.spark.mllib.tree.model.{Split => OldSplit}
    +
    +
    +/**
    + * Interface for a "Split," which specifies a test made at a decision tree node
    + * to choose the left or right path.
    + */
    +trait Split extends Serializable {
    +
    +  /** Index of feature which this split tests */
    +  def feature: Int
    +
    +  /** Return true (split to left) or false (split to right) */
    +  private[ml] def goLeft(features: Vector): Boolean
    +
    +  /** Convert to old Split format */
    +  private[tree] def toOld: OldSplit
    +}
    +
    +private[ml] object Split {
    +
    +  def fromOld(oldSplit: OldSplit): Split = {
    +    oldSplit.featureType match {
    +      case OldFeatureType.Categorical =>
    +        new CategoricalSplit(feature = oldSplit.feature,
    +          categories = oldSplit.categories.toSet)
    +      case OldFeatureType.Continuous =>
    +        new ContinuousSplit(feature = oldSplit.feature, threshold = oldSplit.threshold)
    +    }
    +  }
    +
    +}
    +
    +/**
    + * Split which tests a categorical feature.
    + * @param feature  Index of the feature to test
    + * @param categories  If the feature value is in this set of categories, then the split goes left.
    + *                    Otherwise, it goes right.
    + */
    +class CategoricalSplit(override val feature: Int, val categories: Set[Double]) extends Split {
    --- End diff --
    
    Sounds good.
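
To make the `Split` contract concrete, here is a minimal Spark-free sketch of the two split types. It assumes features arrive as a plain `Array[Double]`, uses the `featureIndex` name suggested elsewhere in this review, and assumes continuous splits send values `<= threshold` left; the categorical rule follows the `CategoricalSplit` doc comment. All `*Sketch` names are hypothetical.

```scala
// Hypothetical, Spark-free sketch of the Split abstraction under review.
sealed trait SplitSketch {
  /** Index of the feature this split tests. */
  def featureIndex: Int

  /** true means the instance goes to the left child. */
  def goLeft(features: Array[Double]): Boolean
}

/** Goes left iff the feature value is in `categories`, as documented for CategoricalSplit. */
final class CategoricalSplitSketch(val featureIndex: Int, val categories: Set[Double])
  extends SplitSketch {
  override def goLeft(features: Array[Double]): Boolean =
    categories.contains(features(featureIndex))
}

/** Goes left iff the feature value is <= threshold (assumed continuous rule). */
final class ContinuousSplitSketch(val featureIndex: Int, val threshold: Double)
  extends SplitSketch {
  override def goLeft(features: Array[Double]): Boolean =
    features(featureIndex) <= threshold
}
```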



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93658109
  
    @jkbradley I made one pass. Had some minor comments about the method scopes and naming. It looks good in general.



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28486998
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Split.scala ---
    @@ -0,0 +1,105 @@
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
    +import org.apache.spark.mllib.tree.model.{Split => OldSplit}
    +
    +
    +/**
    + * Interface for a "Split," which specifies a test made at a decision tree node
    + * to choose the left or right path.
    + */
    +trait Split extends Serializable {
    +
    +  /** Index of feature which this split tests */
    +  def feature: Int
    --- End diff --
    
    `feature` -> `featureIndex`?



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28487029
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala ---
    @@ -0,0 +1,274 @@
    +package org.apache.spark.ml.classification
    +
    +import org.scalatest.FunSuite
    +
    +import org.apache.spark.ml.impl.TreeTests
    +import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
    +  DecisionTreeSuite => OldDecisionTreeSuite}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.DataFrame
    +
    +
    +class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
    +
    +  import DecisionTreeClassifierSuite.compareAPIs
    +
    +  private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
    +  private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _
    +  private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _
    +  private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
    +  private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
    +  private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
    +
    +  override def beforeAll() {
    +    super.beforeAll()
    +    categoricalDataPointsRDD =
    +      sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
    +    orderedLabeledPointsWithLabel0RDD =
    +      sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0())
    +    orderedLabeledPointsWithLabel1RDD =
    +      sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1())
    +    categoricalDataPointsForMulticlassRDD =
    +      sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass())
    +    continuousDataPointsForMulticlassRDD =
    +      sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass())
    +    categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize(
    +      OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
    +  }
    +
    +  /////////////////////////////////////////////////////////////////////////////
    +  // Tests calling train()
    +  /////////////////////////////////////////////////////////////////////////////
    +
    +  test("Binary classification stump with ordered categorical features") {
    +    val dt = new DecisionTreeClassifier()
    +      .setImpurity("gini")
    +      .setMaxDepth(2)
    +      .setMaxBins(100)
    +    val categoricalFeatures = Map(0 -> 3, 1 -> 3)
    +    val numClasses = 2
    +    compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
    +  }
    +
    +  test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") {
    +    val dt = new DecisionTreeClassifier()
    +      .setMaxDepth(3)
    +      .setMaxBins(100)
    +    val numClasses = 2
    +    Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd =>
    +      DecisionTreeClassifier.supportedImpurities.foreach { impurity =>
    +        dt.setImpurity(impurity)
    +        compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
    +      }
    +    }
    +  }
    +
    +  test("Multiclass classification stump with 3-ary (unordered) categorical features") {
    +    val rdd = categoricalDataPointsForMulticlassRDD
    +    val dt = new DecisionTreeClassifier()
    +      .setImpurity("Gini")
    +      .setMaxDepth(4)
    +    val numClasses = 3
    +    val categoricalFeatures = Map(0 -> 3, 1 -> 3)
    +    compareAPIs(rdd, dt, categoricalFeatures, numClasses)
    +  }
    +
    +  test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
    +    val arr = new Array[LabeledPoint](4)
    +    arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
    +    arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
    +    arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
    +    arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
    --- End diff --
    
    ~~~scala
    val arr = Array(
      new LabeledPoint(...),
      new ...
    )
    ~~~

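The off-by-one test uses labels 0, 1, 1, 1 on feature values 0.0 to 3.0, so a correct stump should isolate only the first point on the left. A minimal sketch of why, assuming an exhaustive search over observed values as thresholds (instances with value `<= threshold` go left) and weighted Gini impurity; this is not Spark's binning-based algorithm, and `StumpSketch` is hypothetical:

```scala
// Hypothetical sketch: exhaustive threshold search for a one-feature stump using
// weighted Gini impurity. Thresholds are the observed feature values, and
// instances with value <= threshold go to the left child.
object StumpSketch {
  /** Gini impurity of a collection of class labels: 1 - sum_k p_k^2. */
  def gini(labels: Seq[Double]): Double = {
    val n = labels.size.toDouble
    if (n == 0) 0.0
    else 1.0 - labels.groupBy(identity).values.map(g => math.pow(g.size / n, 2)).sum
  }

  /** points are (label, featureValue) pairs; returns the threshold with the lowest
    * weighted child impurity (the first one found on ties). */
  def bestThreshold(points: Seq[(Double, Double)]): Double = {
    val n = points.size.toDouble
    val values = points.map(_._2).distinct.sorted
    require(values.size >= 2, "need at least two distinct feature values to split")
    // The largest value is skipped: splitting there would leave the right child empty.
    values.init.minBy { t =>
      val (left, right) = points.partition(_._2 <= t)
      left.size / n * gini(left.map(_._1)) + right.size / n * gini(right.map(_._1))
    }
  }
}
```

With these four points, threshold 0.0 yields two pure children (weighted Gini 0), while thresholds 1.0 and 2.0 yield 0.25 and 1/3 respectively.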


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28487024
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala ---
    @@ -0,0 +1,300 @@
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.impl.estimator.PredictorParams
    +import org.apache.spark.ml.param._
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
    +  Impurity => OldImpurity, Variance => OldVariance}
    +
    +
    +/**
    + * :: DeveloperApi ::
    + * Parameters for Decision Tree-based algorithms.
    + * @tparam M  Concrete class implementing this parameter trait
    + */
    +@DeveloperApi
    +private[ml] trait DecisionTreeParams[M] extends PredictorParams {
    +
    +  /**
    +   * Maximum depth of the tree.
    +   * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
    +   * (default = 5)
    +   * @group param
    +   */
    +  val maxDepth: IntParam =
    +    new IntParam(this, "maxDepth", "Maximum depth of the tree." +
    +      " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
    +
    +  /**
    +   * Maximum number of bins used for discretizing continuous features and for choosing how to split
    +   * on features at each node.  More bins give higher granularity.
    +   * Must be >= 2 and >= number of categories in any categorical feature.
    +   * (default = 32)
    +   * @group param
    +   */
    +  val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for discretizing" +
    +    " continuous features.  Must be >=2 and >= number of categories for any categorical feature.")
    +
    +  /**
    +   * Minimum number of instances each child must have after split.
    +   * If a split causes the left or right child to have fewer than minInstancesPerNode,
    +   * this split will not be considered as a valid split.
    +   * Should be >= 1.
    +   * (default = 1)
    +   * @group param
    +   */
    +  val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum number" +
    +    " of instances each child must have after split.  If a split causes the left or right child" +
    +    " to have fewer than minInstancesPerNode, the split will not be considered valid." +
    +    " Should be >= 1.")
    +
    +  /**
    +   * Minimum information gain for a split to be considered at a tree node.
    +   * (default = 0.0)
    +   * @group param
    +   */
    +  val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
    +    "Minimum information gain for a split to be considered at a tree node.")
    +
    +  /**
    +   * Maximum memory in MB allocated to histogram aggregation.
    +   * (default = 256 MB)
    +   * @group expertParam
    +   */
    +  val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
    +    "Maximum memory in MB allocated to histogram aggregation.")
    +
    +  /**
    +   * If false, the algorithm will pass trees to executors to match instances with nodes.
    +   * If true, the algorithm will cache node IDs for each instance.
    +   * Caching can speed up training of deeper trees.
    +   * (default = false)
    +   * @group expertParam
    +   */
    +  val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
    +    " algorithm will pass trees to executors to match instances with nodes. If true, the" +
    +    " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
    +    " trees.")
    +
    +  /**
    +   * Specifies how often to checkpoint the cached node IDs.
    +   * E.g. 10 means that the cache will get checkpointed every 10 iterations.
    +   * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
    +   * [[org.apache.spark.SparkContext]].
    +   * Must be >= 1.
    +   * (default = 10)
    +   * @group expertParam
    +   */
    +  val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies how" +
    +    " often to checkpoint the cached node IDs.  E.g. 10 means that the cache will get" +
    +    " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
    +    " checkpoint directory is set in the SparkContext. Must be >= 1.")
    +
    +  setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
    +    maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
    +
    +  /** @group setParam */
    +  def setMaxDepth(value: Int): M = {
    +    require(value >= 0, s"maxDepth parameter must be >= 0.  Given bad value: $value")
    +    set(maxDepth, value)
    +    this.asInstanceOf[M]
    +  }
    +
    +  /** @group getParam */
    +  def getMaxDepth: Int = getOrDefault(maxDepth)
    +
    +  /** @group setParam */
    +  def setMaxBins(value: Int): M = {
    +    require(value >= 2, s"maxBins parameter must be >= 2.  Given bad value: $value")
    +    set(maxBins, value)
    +    this.asInstanceOf[M]
    +  }
    +
    +  /** @group getParam */
    +  def getMaxBins: Int = getOrDefault(maxBins)
    +
    +  /** @group setParam */
    +  def setMinInstancesPerNode(value: Int): M = {
    +    require(value >= 1, s"minInstancesPerNode parameter must be >= 1.  Given bad value: $value")
    +    set(minInstancesPerNode, value)
    +    this.asInstanceOf[M]
    +  }
    +
    +  /** @group getParam */
    +  def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
    +
    +  /** @group setParam */
    +  def setMinInfoGain(value: Double): M = {
    +    set(minInfoGain, value)
    +    this.asInstanceOf[M]
    +  }
    +
    +  /** @group getParam */
    +  def getMinInfoGain: Double = getOrDefault(minInfoGain)
    +
    +  /** @group expertSetParam */
    +  def setMaxMemoryInMB(value: Int): M = {
    +    require(value > 0, s"maxMemoryInMB parameter must be > 0.  Given bad value: $value")
    +    set(maxMemoryInMB, value)
    +    this.asInstanceOf[M]
    +  }
    +
    +  /** @group expertGetParam */
    +  def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
    +
    +  /** @group expertSetParam */
    +  def setCacheNodeIds(value: Boolean): M = {
    +    set(cacheNodeIds, value)
    +    this.asInstanceOf[M]
    +  }
    +
    +  /** @group expertGetParam */
    +  def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
    +
    +  /** @group expertSetParam */
    +  def setCheckpointInterval(value: Int): M = {
    +    require(value >= 1, s"checkpointInterval parameter must be >= 1.  Given bad value: $value")
    +    set(checkpointInterval, value)
    +    this.asInstanceOf[M]
    +  }
    +
    +  /** @group expertGetParam */
    +  def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
    +
    +  /**
    +   * Create a Strategy instance to use with the old API.
    +   * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0,
    +   *       the default for single trees).
    +   */
    +  private[ml] def getOldStrategy(
    +      categoricalFeatures: Map[Int, Int],
    +      numClasses: Int): OldStrategy = {
    +    val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
    +    strategy.checkpointInterval = getCheckpointInterval
    +    strategy.maxBins = getMaxBins
    +    strategy.maxDepth = getMaxDepth
    +    strategy.maxMemoryInMB = getMaxMemoryInMB
    +    strategy.minInfoGain = getMinInfoGain
    +    strategy.minInstancesPerNode = getMinInstancesPerNode
    +    strategy.useNodeIdCache = getCacheNodeIds
    +    strategy.numClasses = numClasses
    +    strategy.categoricalFeaturesInfo = categoricalFeatures
    +    strategy.subsamplingRate = 1.0 // default for individual trees
    +    strategy
    +  }
    +}
    +
    +/**
    + * (private trait) Parameters for Decision Tree-based classification algorithms.
    + * @tparam M  Concrete class implementing this parameter trait
    + */
    +private[ml] trait TreeClassifierParams[M] extends Params {
    +
    +  /**
    +   * Criterion used for information gain calculation (case-insensitive).
    +   * Supported: "entropy" and "gini".
    +   * (default = gini)
    +   * @group param
    +   */
    +  val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
    +    " information gain calculation (case-insensitive). Supported options:" +
    +    s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
    +
    +  setDefault(impurity -> "gini")
    +
    +  /** @group setParam */
    +  def setImpurity(value: String): M = {
    +    val impurityStr = value.toLowerCase
    +    require(TreeClassifierParams.supportedImpurities.contains(impurityStr),
    +      s"Tree-based classifier was given unrecognized impurity: $value." +
    +      s"  Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
    +    set(impurity, impurityStr)
    +    this.asInstanceOf[M]
    +  }
    +
    +  /** @group getParam */
    +  def getImpurity: String = getOrDefault(impurity)
    +
    +  /** Convert new impurity to old impurity. */
    +  protected def getOldImpurity: OldImpurity = {
    --- End diff --
    
    Should be package private.
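
The `impurity` param supports "entropy" and "gini", matched case-insensitively. A minimal sketch of both measures computed from per-class counts, plus a lookup that rejects unknown names the way `setImpurity` does; entropy is measured in bits here (an assumption), and `ImpuritySketch` is hypothetical:

```scala
// Hypothetical sketch of the two classification impurities named in the param doc.
object ImpuritySketch {
  val supportedImpurities: Set[String] = Set("entropy", "gini")

  /** Gini impurity from per-class counts: 1 - sum_k p_k^2. */
  def gini(counts: Seq[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  /** Entropy in bits from per-class counts: -sum_k p_k log2 p_k (0 log 0 treated as 0). */
  def entropy(counts: Seq[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0
    else counts.filter(_ > 0).map { c =>
      val p = c / total
      -p * math.log(p) / math.log(2)
    }.sum
  }

  /** Case-insensitive selection that rejects unrecognized names, like setImpurity. */
  def impurityFor(name: String): Seq[Double] => Double = {
    val key = name.toLowerCase(java.util.Locale.ROOT)
    require(supportedImpurities.contains(key),
      s"Unrecognized impurity: $name. Supported options: ${supportedImpurities.mkString(", ")}")
    if (key == "gini") gini else entropy
  }
}
```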

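`DecisionTreeParams[M]` combines defaults, validated setters, and setters that return the concrete type `M` so calls can be chained. A Spark-free sketch of that pattern, assuming a simple string-keyed map in place of Spark's `Param` machinery (`TreeParamsSketch` and `ClassifierSketch` are hypothetical):

```scala
import scala.collection.mutable

// Hypothetical sketch of the defaults + validated, chainable setters pattern.
trait TreeParamsSketch[M] {
  private val values = mutable.Map[String, Any]()
  private val defaults = Map[String, Any]("maxDepth" -> 5, "maxBins" -> 32)

  // Explicitly set value if present, otherwise the default.
  protected def getOrDefault[T](name: String): T =
    values.getOrElse(name, defaults(name)).asInstanceOf[T]

  def setMaxDepth(value: Int): M = {
    require(value >= 0, s"maxDepth parameter must be >= 0.  Given bad value: $value")
    values("maxDepth") = value
    this.asInstanceOf[M] // return the concrete type so calls can be chained
  }
  def getMaxDepth: Int = getOrDefault[Int]("maxDepth")

  def setMaxBins(value: Int): M = {
    require(value >= 2, s"maxBins parameter must be >= 2.  Given bad value: $value")
    values("maxBins") = value
    this.asInstanceOf[M]
  }
  def getMaxBins: Int = getOrDefault[Int]("maxBins")
}

/** Hypothetical concrete estimator: M is the class itself, so setters return it. */
class ClassifierSketch extends TreeParamsSketch[ClassifierSketch]
```

Returning `this.type` from the setters is a common Scala alternative that avoids the `asInstanceOf[M]` cast.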

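The `maxDepth` doc says depth 0 means 1 leaf and depth 1 means 1 internal node plus 2 leaves. Generalizing, a full binary tree of depth d has at most 2^d leaves, 2^d - 1 internal nodes, and 2^(d+1) - 1 nodes in total, so the default `maxDepth = 5` allows at most 32 leaves. A small helper sketch (hypothetical `DepthSketch`):

```scala
// Upper bounds implied by maxDepth, computed for a full binary tree of depth d.
// Actual trees may be smaller, e.g. when minInstancesPerNode or minInfoGain stop splitting.
object DepthSketch {
  def maxLeaves(depth: Int): Long = {
    require(depth >= 0 && depth <= 62, s"depth out of range: $depth")
    1L << depth
  }
  def maxInternalNodes(depth: Int): Long = maxLeaves(depth) - 1
  def maxNodes(depth: Int): Long = maxLeaves(depth) + maxInternalNodes(depth)
}
```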

[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by SparkQA <gi...@git.apache.org>.
Github user SparkQA commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93579127
  
      [Test build #30371 has finished](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/30371/consoleFull) for PR 5530 at commit [`5626c81`](https://github.com/apache/spark/commit/5626c81cc896a04fc109d09ee145487c116ecf6d).
     * This patch **fails PySpark unit tests**.
     * This patch merges cleanly.
     * This patch adds the following public classes _(experimental)_:
      * `  case class Params(`
      * `sealed trait Node extends Serializable `
      * `trait Split extends Serializable `
      * `class CategoricalSplit(override val feature: Int, val categories: Set[Double]) extends Split `
      * `class ContinuousSplit(override val feature: Int, val threshold: Double) extends Split `
      * `trait DecisionTreeModel `
    
     * This patch does not change any dependencies.



[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28532385
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +package org.apache.spark.examples.ml
    +
    +import scala.collection.mutable
    +import scala.language.reflectiveCalls
    +
    +import scopt.OptionParser
    +
    +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.examples.mllib.AbstractParams
    +import org.apache.spark.ml.{Pipeline, PipelineStage}
    +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
    +import org.apache.spark.ml.impl.tree.DecisionTreeModel
    +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.StringType
    +import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.sql.functions.callUDF
    +
    +
    +/**
    + * An example runner for decision trees and random forests. Run with
    + * {{{
    + * ./bin/run-example ml.DecisionTreeExample [options]
    + * }}}
    + * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    + */
    +object DecisionTreeExample {
    +
    +  case class Params(
    +      input: String = null,
    +      testInput: String = "",
    +      dataFormat: String = "libsvm",
    +      algo: String = "Classification",
    +      maxDepth: Int = 5,
    +      maxBins: Int = 32,
    +      minInstancesPerNode: Int = 1,
    +      minInfoGain: Double = 0.0,
    +      numTrees: Int = 1,
    +      featureSubsetStrategy: String = "auto",
    +      fracTest: Double = 0.2,
    +      cacheNodeIds: Boolean = false,
    +      checkpointDir: Option[String] = None,
    +      checkpointInterval: Int = 10) extends AbstractParams[Params]
    +
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    --- End diff --
    
    I actually made it case-insensitive.

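The `--algo` option is matched case-insensitively, so "Classification" and "classification" are equivalent. A minimal sketch of that parsing, assuming a hypothetical `AlgoSketch` ADT rather than the example's raw strings:

```scala
// Hypothetical sketch of case-insensitive parsing for the example's --algo option.
object AlgoSketch {
  sealed trait Algo
  case object Classification extends Algo
  case object Regression extends Algo

  def parse(algo: String): Algo = algo.toLowerCase(java.util.Locale.ROOT) match {
    case "classification" => Classification
    case "regression" => Regression
    case _ => throw new IllegalArgumentException(s"Algo $algo not supported.")
  }
}
```

Passing `Locale.ROOT` to `toLowerCase` keeps the match independent of the JVM's default locale (for example, the Turkish dotless i).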


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476283
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +package org.apache.spark.examples.ml
    +
    +import scala.collection.mutable
    +import scala.language.reflectiveCalls
    +
    +import scopt.OptionParser
    +
    +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.examples.mllib.AbstractParams
    +import org.apache.spark.ml.{Pipeline, PipelineStage}
    +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
    +import org.apache.spark.ml.impl.tree.DecisionTreeModel
    +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.StringType
    +import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.sql.functions.callUDF
    +
    +
    +/**
    + * An example runner for decision trees and random forests. Run with
    + * {{{
    + * ./bin/run-example ml.DecisionTreeExample [options]
    + * }}}
    + * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    + */
    +object DecisionTreeExample {
    +
    +  case class Params(
    +      input: String = null,
    +      testInput: String = "",
    +      dataFormat: String = "libsvm",
    +      algo: String = "Classification",
    +      maxDepth: Int = 5,
    +      maxBins: Int = 32,
    +      minInstancesPerNode: Int = 1,
    +      minInfoGain: Double = 0.0,
    +      numTrees: Int = 1,
    +      featureSubsetStrategy: String = "auto",
    +      fracTest: Double = 0.2,
    +      cacheNodeIds: Boolean = false,
    +      checkpointDir: Option[String] = None,
    +      checkpointInterval: Int = 10) extends AbstractParams[Params]
    +
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    +        .action((x, c) => c.copy(algo = x))
    +      opt[Int]("maxDepth")
    +        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
    +        .action((x, c) => c.copy(maxDepth = x))
    +      opt[Int]("maxBins")
    +        .text(s"max number of bins, default: ${defaultParams.maxBins}")
    +        .action((x, c) => c.copy(maxBins = x))
    +      opt[Int]("minInstancesPerNode")
    +        .text(s"min number of instances required at child nodes to create the parent split," +
    +          s" default: ${defaultParams.minInstancesPerNode}")
    +        .action((x, c) => c.copy(minInstancesPerNode = x))
    +      opt[Double]("minInfoGain")
    +        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
    +        .action((x, c) => c.copy(minInfoGain = x))
    +      opt[Double]("fracTest")
    +        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
    +          s"this option is ignored. default: ${defaultParams.fracTest}")
    +        .action((x, c) => c.copy(fracTest = x))
    +      opt[Boolean]("cacheNodeIds")
    +        .text(s"whether to use node Id cache during training, " +
    +          s"default: ${defaultParams.cacheNodeIds}")
    +        .action((x, c) => c.copy(cacheNodeIds = x))
    +      opt[String]("checkpointDir")
    +        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
    +         s"default: ${defaultParams.checkpointDir match {
    +           case Some(strVal) => strVal
    +           case None => "None"
    +         }}")
    +        .action((x, c) => c.copy(checkpointDir = Some(x)))
    +      opt[Int]("checkpointInterval")
    +        .text(s"how often to checkpoint the node Id cache, " +
    +         s"default: ${defaultParams.checkpointInterval}")
    +        .action((x, c) => c.copy(checkpointInterval = x))
    +      opt[String]("testInput")
    +        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
    +          s" default: ${defaultParams.testInput}")
    +        .action((x, c) => c.copy(testInput = x))
    +      opt[String]("<dataFormat>")
    +        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
    +        .action((x, c) => c.copy(dataFormat = x))
    +      arg[String]("<input>")
    +        .text("input path to labeled examples")
    +        .required()
    +        .action((x, c) => c.copy(input = x))
    +      checkConfig { params =>
    +        if (params.fracTest < 0 || params.fracTest > 1) {
    +          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
    +        } else {
    +          success
    +        }
    +      }
    +    }
    +
    +    parser.parse(args, defaultParams).map { params =>
    +      run(params)
    +    }.getOrElse {
    +      sys.exit(1)
    +    }
    +  }
    +
    +  /** Load a dataset from the given path, using the given format */
    +  private[ml] def loadData(
    +      sc: SparkContext,
    +      path: String,
    +      format: String,
    +      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
    +    format match {
    +      case "dense" => MLUtils.loadLabeledPoints(sc, path)
    +      case "libsvm" => expectedNumFeatures match {
    +        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
    +        case None => MLUtils.loadLibSVMFile(sc, path)
    +      }
    +      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
    +    }
    +  }
    +
    +  /**
    +   * Load training and test data from files.
    +   * @param input  Path to input dataset.
    +   * @param dataFormat  "libsvm" or "dense"
    +   * @param testInput  Path to test dataset.
    +   * @param algo  Classification or Regression
    +   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
    +   * @return  (training dataset, test dataset)
    +   */
    +  private[ml] def loadDatasets(
    +      sc: SparkContext,
    +      input: String,
    +      dataFormat: String,
    +      testInput: String,
    +      algo: String,
    +      fracTest: Double): (DataFrame, DataFrame) = {
    +    val sqlContext = new SQLContext(sc)
    +    import sqlContext.implicits._
    +
    +    // Load training data
    +    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
    +
    +    // Load or create test set
    +    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
    +      // Load testInput.
    +      val numFeatures = origExamples.take(1)(0).features.size
    +      val origTestExamples: RDD[LabeledPoint] = loadData(sc, testInput, dataFormat, Some(numFeatures))
    +      Array(origExamples, origTestExamples)
    +    } else {
    +      // Split input into training, test.
    +      origExamples.randomSplit(Array(1.0 - fracTest, fracTest))
    +    }
    +
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    +      algo.toLowerCase match {
    +        case "classification" =>
    +          val convertToString: Double => String = (label: Double) => label.toString
    +          data.select(
    +            callUDF(convertToString, StringType, data("label")).as("labelString"), data("features"))
    +        case "regression" =>
    +          data
    +        case _ =>
    +          throw new IllegalArgumentException(s"Algo $algo not supported.")
    +      }
    +    }
    +    val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
    +
    +    (dataframes(0), dataframes(1))
    +  }
    +
    +  def run(params: Params) {
    +    val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params")
    +    val sc = new SparkContext(conf)
    +    params.checkpointDir.foreach(sc.setCheckpointDir)
    +    val algo = params.algo.toLowerCase
    +
    +    println(s"DecisionTreeExample with parameters:\n$params")
    +
    +    // Load training and test data and cache it.
    +    val (training: DataFrame, test: DataFrame) =
    +      loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
    +
    +    val numTraining = training.count()
    +    val numTest = test.count()
    +    val numFeatures = training.select("features").first().getAs[Vector](0).size
    +    println("Loaded data:")
    +    println(s"  numTraining = $numTraining, numTest = $numTest")
    +    println(s"  numFeatures = $numFeatures")
    +
    +    // Set up Pipeline
    +    val stages = new mutable.ArrayBuffer[PipelineStage]()
    +    // (1) For classification, re-index classes.
    +    if (algo == "classification") {
    +      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol("label")
    +      stages += labelIndexer
    +    }
    +    // (2) Identify categorical features using VectorIndexer.
    +    //     Features with more than maxCategories values will be treated as continuous.
    +    val featuresIndexer = new VectorIndexer().setInputCol("features")
    --- End diff --
    
    It may read better if we chop down the setters.
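"Chopping down" here means putting each chained setter on its own line. Because the real `VectorIndexer` needs a Spark dependency, this sketch uses a hypothetical stand-in class, `IndexerConf`, purely to show the layout; the output column name and the `maxCategories` value are illustrative only.

```scala
// Hypothetical stand-in with fluent setters; it only mimics the
// setter-chaining style of Pipeline stages such as VectorIndexer.
class IndexerConf {
  var inputCol: String = ""
  var outputCol: String = ""
  var maxCategories: Int = 0

  def setInputCol(value: String): this.type = { inputCol = value; this }
  def setOutputCol(value: String): this.type = { outputCol = value; this }
  def setMaxCategories(value: Int): this.type = { maxCategories = value; this }
}

// "Chopped down": one setter per line, indented under the constructor call.
val featuresIndexer = new IndexerConf()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(10)

println(s"${featuresIndexer.inputCol} -> ${featuresIndexer.outputCol}")
```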


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28626389
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    --- End diff --
    
    Sounds good. I think we should convert the input labels to strings first and store the string labels instead of an arbitrary type, just to ease model export/import. Let's make a JIRA and discuss this.
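A minimal sketch of that idea, assuming raw labels of any type are first converted with `toString` and then indexed most-frequent-first (similar in spirit to how `StringIndexer` orders labels by frequency; the tie-break on string order is an assumption). The only state a saved model would then need is the `Array[String]` of labels. `LabelIndex` and `LabelIndexer` are hypothetical names.

```scala
// Hypothetical model-side state: only an Array[String] needs to be saved.
final class LabelIndex(val labels: Array[String]) {
  private val toIndex: Map[String, Int] = labels.zipWithIndex.toMap

  /** Index of a raw label of any type, looked up via its string form. */
  def index(label: Any): Int = toIndex(label.toString)

  /** String label stored for a given index. */
  def label(i: Int): String = labels(i)
}

object LabelIndexer {
  /** Most frequent label gets index 0; ties are broken by string order (an assumption). */
  def fit(rawLabels: Seq[Any]): LabelIndex = {
    val counts: Map[String, Int] =
      rawLabels.map(_.toString).groupBy(identity).map { case (s, xs) => s -> xs.size }
    val ordered = counts.toSeq.sortBy { case (s, n) => (-n, s) }.map(_._1)
    new LabelIndex(ordered.toArray)
  }
}

val idx = LabelIndexer.fit(Seq(1.0, 0.0, 1.0, 2.0, 1.0, 0.0))
println(idx.labels.mkString(", "))  // 1.0, 0.0, 2.0
println(idx.index(2.0))             // 2
```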


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28535102
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Split.scala ---
    @@ -0,0 +1,105 @@
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
    +import org.apache.spark.mllib.tree.model.{Split => OldSplit}
    +
    +
    +/**
    + * Interface for a "Split," which specifies a test made at a decision tree node
    + * to choose the left or right path.
    + */
    +trait Split extends Serializable {
    +
    +  /** Index of feature which this split tests */
    +  def feature: Int
    +
    +  /** Return true (split to left) or false (split to right) */
    +  private[ml] def goLeft(features: Vector): Boolean
    +
    +  /** Convert to old Split format */
    +  private[tree] def toOld: OldSplit
    +}
    +
    +private[ml] object Split {
    +
    +  def fromOld(oldSplit: OldSplit): Split = {
    +    oldSplit.featureType match {
    +      case OldFeatureType.Categorical =>
    +        new CategoricalSplit(feature = oldSplit.feature,
    +          categories = oldSplit.categories.toSet)
    +      case OldFeatureType.Continuous =>
    +        new ContinuousSplit(feature = oldSplit.feature, threshold = oldSplit.threshold)
    +    }
    +  }
    +
    +}
    +
    +/**
    + * Split which tests a categorical feature.
    + * @param feature  Index of the feature to test
    + * @param categories  If the feature value is in this set of categories, then the split goes left.
    + *                    Otherwise, it goes right.
    + */
    +class CategoricalSplit(override val feature: Int, val categories: Set[Double]) extends Split {
    --- End diff --
    
    I was thinking more about optimizing running time (if there were thousands of categories), rather than storage.  I can make it an Array and add a to-do for sorting the Array and using binary search.
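A sketch of the Array-based alternative, assuming the left-going categories are kept in a sorted `Array[Double]` and membership is tested with `java.util.Arrays.binarySearch`, which costs O(log k) per lookup for k categories. `SortedCategories` is a hypothetical name, not a class from the PR.

```scala
import java.util.Arrays

// Hypothetical compact representation of the categories that go left.
final class SortedCategories(categories: Array[Double]) {
  // Deduplicate into a fresh array and sort it once, so lookups can binary-search.
  private val sorted: Array[Double] = {
    val copy = categories.distinct
    Arrays.sort(copy)
    copy
  }

  // binarySearch returns a non-negative index iff the value is present.
  def contains(value: Double): Boolean = Arrays.binarySearch(sorted, value) >= 0

  def toArray: Array[Double] = sorted.clone()
}

val left = new SortedCategories(Array(5.0, 1.0, 3.0, 1.0))
println(left.toArray.mkString(", "))  // 1.0, 3.0, 5.0
println(left.contains(3.0))           // true
println(left.contains(2.0))           // false
```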


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley closed the pull request at:

    https://github.com/apache/spark/pull/5530


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28487001
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Split.scala ---
    @@ -0,0 +1,105 @@
    +  /** Return true (split to left) or false (split to right) */
    +  private[ml] def goLeft(features: Vector): Boolean
    --- End diff --
    
    It returns a boolean. So it might be better to call it `shouldGoLeft`.
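To make the proposed name concrete, here is a minimal sketch of `shouldGoLeft` for both kinds of split. It assumes features arrive as a plain `Array[Double]` instead of an MLlib `Vector`, and that a continuous split sends values `<= threshold` to the left (an assumption here, chosen to mirror the old MLlib prediction logic). The `*Sketch` class names are hypothetical.

```scala
// Sketch only: features are a plain Array[Double] rather than an MLlib Vector.
trait SplitSketch {
  def featureIndex: Int
  /** true => route the example to the left child; false => right child. */
  def shouldGoLeft(features: Array[Double]): Boolean
}

final class ContinuousSplitSketch(val featureIndex: Int, val threshold: Double)
  extends SplitSketch {
  // Assumed convention: values at or below the threshold go left.
  override def shouldGoLeft(features: Array[Double]): Boolean =
    features(featureIndex) <= threshold
}

final class CategoricalSplitSketch(val featureIndex: Int, leftCategories: Set[Double])
  extends SplitSketch {
  // Values listed in leftCategories go left; every other value goes right.
  override def shouldGoLeft(features: Array[Double]): Boolean =
    leftCategories.contains(features(featureIndex))
}

val x = Array(0.5, 2.0)
println(new ContinuousSplitSketch(0, 1.0).shouldGoLeft(x))            // true
println(new CategoricalSplitSketch(1, Set(0.0, 1.0)).shouldGoLeft(x))  // false
```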


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by SparkQA <gi...@git.apache.org>.
Github user SparkQA commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93864879
  
      [Test build #30439 has finished](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/30439/consoleFull) for PR 5530 at commit [`6aae255`](https://github.com/apache/spark/commit/6aae25587cdcadc0e5d68078ca77d0cdee59e6e4).
     * This patch **passes all tests**.
     * This patch merges cleanly.
     * This patch adds the following public classes _(experimental)_:
      * `  case class Params(`
      * `sealed abstract class Node extends Serializable `
      * `sealed trait Split extends Serializable `
      * `final class CategoricalSplit(`
      * `final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split `
      * `trait DecisionTreeModel `
    
     * This patch does not change any dependencies.


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476296
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Node.scala ---
    @@ -0,0 +1,201 @@
    +package org.apache.spark.ml.impl.tree
    --- End diff --
    
    `impl` is usually used for private code or implementation of a public interface. So we should either move `Node` to `ml.tree` and keep the subclasses under `impl` (then we cannot use sealed), or move everything to `ml.tree`.
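Because `sealed` only allows subclasses declared in the same source file, keeping `Node` and its subclasses together (for example, all under `ml.tree`) is what lets the hierarchy stay sealed. A minimal sketch of the payoff, with hypothetical `LeafNode`/`InternalNode` classes and continuous-only splits (`<= threshold` goes left, an assumption): prediction can match on the node type without a fallback case, and the compiler can warn if a case is missing.

```scala
// Because Node is sealed, every subclass must be declared in this same file.
sealed abstract class Node extends Serializable

final class LeafNode(val prediction: Double) extends Node

final class InternalNode(
    val featureIndex: Int,
    val threshold: Double,
    val leftChild: Node,
    val rightChild: Node) extends Node

object NodeOps {
  // Exhaustive match: LeafNode and InternalNode are the only possible cases.
  @annotation.tailrec
  def predict(node: Node, features: Array[Double]): Double = node match {
    case leaf: LeafNode => leaf.prediction
    case n: InternalNode =>
      val next = if (features(n.featureIndex) <= n.threshold) n.leftChild else n.rightChild
      predict(next, features)
  }
}

val tree = new InternalNode(0, 1.0, new LeafNode(0.0), new LeafNode(1.0))
println(NodeOps.predict(tree, Array(0.3)))  // 0.0
println(NodeOps.predict(tree, Array(4.0)))  // 1.0
```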


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28486997
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Split.scala ---
    @@ -0,0 +1,105 @@
    +/**
    + * Interface for a "Split," which specifies a test made at a decision tree node
    + * to choose the left or right path.
    + */
    +trait Split extends Serializable {
    --- End diff --
    
    `sealed`?
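For reference, here is a minimal sketch of what sealing `Split` would buy. With `CategoricalSplit` and `ContinuousSplit` as the only subclasses, a match over `Split` needs no default case. The compiler warns if a new subclass is added in the same file but left unhandled, and a subclass defined in any other file is a compile error. This is a hypothetical plain-Scala version. Features are an `Array[Double]` rather than an MLlib `Vector`, and the categorical rule follows the `CategoricalSplit` doc comment in this PR. Sending `value <= threshold` left for continuous splits is an assumption about the convention.

```scala
// Hypothetical sealed Split hierarchy (plain Scala, no Spark types).
object SealedSplitSketch {

  sealed trait Split {
    /** Index of the feature this split tests. */
    def feature: Int
  }

  /** Goes left if the feature value is one of `categories`. */
  final case class CategoricalSplit(feature: Int, categories: Set[Double]) extends Split

  /** Goes left if the feature value is <= `threshold` (assumed convention). */
  final case class ContinuousSplit(feature: Int, threshold: Double) extends Split

  /** Returns true to go left, false to go right. No default case: Split is sealed. */
  def goLeft(split: Split, features: Array[Double]): Boolean = split match {
    case CategoricalSplit(f, cats) => cats.contains(features(f))
    case ContinuousSplit(f, t) => features(f) <= t
  }
}
```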


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28533831
  
    --- Diff: mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---
    @@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
         assert(bins(0).length === 0)
       }
     
    +  test("Avoid aggregation on the last level") {
    +    val arr = new Array[LabeledPoint](4)
    +    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
    +    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
    +    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
    +    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
    +    val input = sc.parallelize(arr)
    +
    +    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
    +      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
    +    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    val topNode = Node.emptyNode(nodeIndex = 1)
    +    assert(topNode.predict.predict === Double.MinValue)
    +    assert(topNode.impurity === -1.0)
    +    assert(topNode.isLeaf === false)
    +
    +    val nodesForGroup = Map((0, Array(topNode)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
    +    )))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +
    +    // don't enqueue leaf nodes into node queue
    +    assert(nodeQueue.isEmpty)
    +
    +    // set impurity and predict for topNode
    +    assert(topNode.predict.predict !== Double.MinValue)
    +    assert(topNode.impurity !== -1.0)
    +
    +    // set impurity and predict for child nodes
    +    assert(topNode.leftNode.get.predict.predict === 0.0)
    +    assert(topNode.rightNode.get.predict.predict === 1.0)
    +    assert(topNode.leftNode.get.impurity === 0.0)
    +    assert(topNode.rightNode.get.impurity === 0.0)
    +  }
    +
    +  test("Avoid aggregation if impurity is 0.0") {
    +    val arr = new Array[LabeledPoint](4)
    +    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
    +    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
    +    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
    +    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
    +    val input = sc.parallelize(arr)
    +
    +    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
    +      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
    +    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    val topNode = Node.emptyNode(nodeIndex = 1)
    +    assert(topNode.predict.predict === Double.MinValue)
    +    assert(topNode.impurity === -1.0)
    +    assert(topNode.isLeaf === false)
    +
    +    val nodesForGroup = Map((0, Array(topNode)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
    +    )))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +
    +    // don't enqueue a node into node queue if its impurity is 0.0
    +    assert(nodeQueue.isEmpty)
    +
    +    // set impurity and predict for topNode
    +    assert(topNode.predict.predict !== Double.MinValue)
    +    assert(topNode.impurity !== -1.0)
    +
    +    // set impurity and predict for child nodes
    +    assert(topNode.leftNode.get.predict.predict === 0.0)
    +    assert(topNode.rightNode.get.predict.predict === 1.0)
    +    assert(topNode.leftNode.get.impurity === 0.0)
    +    assert(topNode.rightNode.get.impurity === 0.0)
    +  }
    +
    +  test("Second level node building with vs. without groups") {
    +    val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
    +    assert(arr.length === 1000)
    +    val rdd = sc.parallelize(arr)
    +    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
    +    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
    +    assert(splits.length === 2)
    +    assert(splits(0).length === 99)
    +    assert(bins.length === 2)
    +    assert(bins(0).length === 100)
    +
    +    // Train a 1-node model
    +    val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
    +      numClasses = 2, maxBins = 100)
    +    val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
    +    val rootNode1 = modelOneNode.topNode.deepCopy()
    +    val rootNode2 = modelOneNode.topNode.deepCopy()
    +    assert(rootNode1.leftNode.nonEmpty)
    +    assert(rootNode1.rightNode.nonEmpty)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    // Single group second level tree construction.
    +    val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
    +      (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +    val children1 = new Array[Node](2)
    +    children1(0) = rootNode1.leftNode.get
    +    children1(1) = rootNode1.rightNode.get
    +
    +    // Train one second-level node at a time.
    +    val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
    +    val treeToNodeToIndexInfoA = Map((0, Map(
    +      (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
    +    nodeQueue.clear()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
    +      nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
    +    val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
    +    val treeToNodeToIndexInfoB = Map((0, Map(
    +      (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
    +    nodeQueue.clear()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
    +      nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
    +    val children2 = new Array[Node](2)
    +    children2(0) = rootNode2.leftNode.get
    +    children2(1) = rootNode2.rightNode.get
    +
    +    // Verify whether the splits obtained using single group and multiple group level
    +    // construction strategies are the same.
    +    for (i <- 0 until 2) {
    +      assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
    +      assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
    +      assert(children1(i).split === children2(i).split)
    +      assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
    +      val stats1 = children1(i).stats.get
    +      val stats2 = children2(i).stats.get
    +      assert(stats1.gain === stats2.gain)
    +      assert(stats1.impurity === stats2.impurity)
    +      assert(stats1.leftImpurity === stats2.leftImpurity)
    +      assert(stats1.rightImpurity === stats2.rightImpurity)
    +      assert(children1(i).predict.predict === children2(i).predict.predict)
    +    }
    +  }
    --- End diff --
    
    Yep, it's organized so that it's easy to match these tests with their equivalents in the new API.
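The two "Avoid aggregation" tests depend on pure nodes. Each child ends up with a single label, which is why the assertions expect `impurity === 0.0` and an empty `nodeQueue` even with `maxDepth = 5`. Here is a minimal sketch of Gini impurity computed from per-class counts (1 minus the sum of squared class proportions). It shows why a pure node scores 0.0 and leaves nothing to gain from further splits. `gini` is a hypothetical helper, not MLlib's `Gini` implementation.

```scala
// Hypothetical helper illustrating Gini impurity; not MLlib's implementation.
object GiniSketch {
  /**
   * Gini impurity from per-class label counts: 1 - sum_k p_k^2.
   * Returns 0.0 for an empty node (assumed convention).
   */
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0
    else 1.0 - counts.map { c => val p = c / total; p * p }.sum
  }
}
```

A node holding only label 0 (counts `[2, 0]`) has impurity 0.0. An evenly mixed node (counts `[2, 2]`) has the two-class maximum of 0.5.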


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28486992
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Node.scala ---
    @@ -0,0 +1,201 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
    +  Node => OldNode, Predict => OldPredict}
    +
    +
    +/**
    + * Decision tree node interface.
    + */
    +sealed trait Node extends Serializable {
    +
    +  // TODO: Add aggregate stats (once available).  This will happen after we move the DecisionTree
    +  //       code into the new API and deprecate the old API.
    +
    +  /** Prediction this node makes (or would make, if it is an internal node) */
    +  def prediction: Double
    +
    +  /** Impurity measure at this node (for training data) */
    +  def impurity: Double
    +
    +  /** Recursive prediction helper method */
    +  private[ml] def predict(features: Vector): Double = prediction
    +
    +  /**
    +   * Get the number of nodes in tree below this node, including leaf nodes.
    +   * E.g., if this is a leaf, returns 0.  If both children are leaves, returns 2.
    +   */
    +  private[tree] def numDescendants: Int
    +
    +  /**
    +   * Recursive print function.
    +   * @param indentFactor  The number of spaces to add to each level of indentation.
    +   */
    +  private[tree] def subtreeToString(indentFactor: Int = 0): String
    +
    +  /**
    +   * Get depth of tree from this node.
    +   * E.g.: Depth 0 means this is a leaf node.  Depth 1 means 1 internal and 2 leaf nodes.
    +   */
    +  private[tree] def subtreeDepth: Int
    +
    +  /**
    +   * Create a copy of this node in the old Node format, recursively creating child nodes as needed.
    +   * @param id  Node ID using old format IDs
    +   */
    +  private[ml] def toOld(id: Int): OldNode
    +}
    +
    +private[ml] object Node {
    +
    +  /**
    +   * Create a new Node from the old Node format, recursively creating child nodes as needed.
    +   */
    +  def fromOld(oldNode: OldNode): Node = {
    +    if (oldNode.isLeaf) {
    +      // TODO: Once the implementation has been moved to this API, then include sufficient
    +      //       statistics here.
    +      new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
    +    } else {
    +      val gain = if (oldNode.stats.nonEmpty) {
    +        oldNode.stats.get.gain
    +      } else {
    +        0.0
    +      }
    +      new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
    +        gain = gain, leftChild = fromOld(oldNode.leftNode.get),
    +        rightChild = fromOld(oldNode.rightNode.get), split = Split.fromOld(oldNode.split.get))
    +    }
    +  }
    +
    +  /**
    +   * Helper method for [[Node.subtreeToString()]].
    +   * @param split  Split to print
    +   * @param left  Indicates whether this is the part of the split going to the left,
    +   *              or that going to the right.
    +   */
    +  private[tree] def splitToString(split: Split, left: Boolean): String = {
    --- End diff --
    
    Should be `private` instead of `private[tree]`, since this is only called by `subtreeToString`.
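Some context on the suggestion: in Scala, a `private` member of an object is visible inside that object and its companion class. It is not visible to other classes in the package. So plain `private` only works if the caller is in `object Node` itself or its companion, and which applies depends on where `subtreeToString` ends up being defined. The sketch below shows the companion-access rule with a hypothetical `InternalNode` class calling a `private` helper on its companion object. The names mirror the diff, but the string format is made up for illustration.

```scala
// Hypothetical example of companion-object private access (not the PR's code).
object CompanionPrivateSketch {

  class InternalNode(val feature: Int, val threshold: Double) {
    // A class may call private members of its companion object.
    def subtreeToString(indent: Int = 0): String = {
      val pad = " " * indent
      s"${pad}If (${InternalNode.splitToString(this, left = true)})\n" +
        s"${pad}Else (${InternalNode.splitToString(this, left = false)})"
    }
  }

  object InternalNode {
    // `private`: visible only inside this object and its companion class.
    private def splitToString(node: InternalNode, left: Boolean): String = {
      val op = if (left) "<=" else ">"
      s"feature ${node.feature} $op ${node.threshold}"
    }
  }
}
```

Any other class that tried to call `InternalNode.splitToString` would fail to compile, which is the narrower visibility the comment asks for.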


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by SparkQA <gi...@git.apache.org>.
Github user SparkQA commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93486037
  
      [Test build #30352 has finished](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/30352/consoleFull) for   PR 5530 at commit [`f8fbd24`](https://github.com/apache/spark/commit/f8fbd24877c522138f8d16d2c1855c498b83ba0c).
     * This patch **fails Scala style tests**.
     * This patch merges cleanly.
     * This patch adds the following public classes _(experimental)_:
      * `  case class Params(`
      * `sealed trait Node extends Serializable `
      * `trait Split extends Serializable `
      * `class CategoricalSplit(override val feature: Int, val categories: Set[Double]) extends Split `
      * `class ContinuousSplit(override val feature: Int, val threshold: Double) extends Split `
      * `trait DecisionTreeModel `
    
     * This patch does not change any dependencies.


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by AmplabJenkins <gi...@git.apache.org>.
Github user AmplabJenkins commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93579142
  
    Test FAILed.
    Refer to this link for build results (access rights to CI server needed): 
    https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/30371/
    Test FAILed.


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28537201
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Split.scala ---
    @@ -0,0 +1,105 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.impl.tree
    +
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
    +import org.apache.spark.mllib.tree.model.{Split => OldSplit}
    +
    +
    +/**
    + * Interface for a "Split," which specifies a test made at a decision tree node
    + * to choose the left or right path.
    + */
    +trait Split extends Serializable {
    +
    +  /** Index of feature which this split tests */
    +  def feature: Int
    +
    +  /** Return true (split to left) or false (split to right) */
    +  private[ml] def goLeft(features: Vector): Boolean
    +
    +  /** Convert to old Split format */
    +  private[tree] def toOld: OldSplit
    +}
    +
    +private[ml] object Split {
    +
    +  def fromOld(oldSplit: OldSplit): Split = {
    +    oldSplit.featureType match {
    +      case OldFeatureType.Categorical =>
    +        new CategoricalSplit(feature = oldSplit.feature,
    +          categories = oldSplit.categories.toSet)
    +      case OldFeatureType.Continuous =>
    +        new ContinuousSplit(feature = oldSplit.feature, threshold = oldSplit.threshold)
    +    }
    +  }
    +
    +}
    +
    +/**
    + * Split which tests a categorical feature.
    + * @param feature  Index of the feature to test
    + * @param categories  If the feature value is in this set of categories, then the split goes left.
    + *                    Otherwise, it goes right.
    + */
    +class CategoricalSplit(override val feature: Int, val categories: Set[Double]) extends Split {
    --- End diff --
    
    I ended up going ahead and doing this to-do, so that part of the code will change a bit.


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-94065922
  
    LGTM. Merged into master. I left the JIRA open for ensembles.


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28532592
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.examples.ml
    +
    +import scala.collection.mutable
    +import scala.language.reflectiveCalls
    +
    +import scopt.OptionParser
    +
    +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.examples.mllib.AbstractParams
    +import org.apache.spark.ml.{Pipeline, PipelineStage}
    +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
    +import org.apache.spark.ml.impl.tree.DecisionTreeModel
    +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.StringType
    +import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.sql.functions.callUDF
    +
    +
    +/**
    + * An example runner for decision trees and random forests. Run with
    + * {{{
    + * ./bin/run-example ml.DecisionTreeExample [options]
    + * }}}
    + * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    + */
    +object DecisionTreeExample {
    +
    +  case class Params(
    +      input: String = null,
    +      testInput: String = "",
    +      dataFormat: String = "libsvm",
    +      algo: String = "Classification",
    +      maxDepth: Int = 5,
    +      maxBins: Int = 32,
    +      minInstancesPerNode: Int = 1,
    +      minInfoGain: Double = 0.0,
    +      numTrees: Int = 1,
    +      featureSubsetStrategy: String = "auto",
    +      fracTest: Double = 0.2,
    +      cacheNodeIds: Boolean = false,
    +      checkpointDir: Option[String] = None,
    +      checkpointInterval: Int = 10) extends AbstractParams[Params]
    +
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    +        .action((x, c) => c.copy(algo = x))
    +      opt[Int]("maxDepth")
    +        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
    +        .action((x, c) => c.copy(maxDepth = x))
    +      opt[Int]("maxBins")
    +        .text(s"max number of bins, default: ${defaultParams.maxBins}")
    +        .action((x, c) => c.copy(maxBins = x))
    +      opt[Int]("minInstancesPerNode")
    +        .text(s"min number of instances required at child nodes to create the parent split," +
    +          s" default: ${defaultParams.minInstancesPerNode}")
    +        .action((x, c) => c.copy(minInstancesPerNode = x))
    +      opt[Double]("minInfoGain")
    +        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
    +        .action((x, c) => c.copy(minInfoGain = x))
    +      opt[Double]("fracTest")
    +        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
    +          s"this option is ignored. default: ${defaultParams.fracTest}")
    +        .action((x, c) => c.copy(fracTest = x))
    +      opt[Boolean]("cacheNodeIds")
    +        .text(s"whether to use node Id cache during training, " +
    +          s"default: ${defaultParams.cacheNodeIds}")
    +        .action((x, c) => c.copy(cacheNodeIds = x))
    +      opt[String]("checkpointDir")
    +        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
    +         s"default: ${defaultParams.checkpointDir match {
    +           case Some(strVal) => strVal
    +           case None => "None"
    +         }}")
    +        .action((x, c) => c.copy(checkpointDir = Some(x)))
    +      opt[Int]("checkpointInterval")
    +        .text(s"how often to checkpoint the node Id cache, " +
    +         s"default: ${defaultParams.checkpointInterval}")
    +        .action((x, c) => c.copy(checkpointInterval = x))
    +      opt[String]("testInput")
    +        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
    +          s" default: ${defaultParams.testInput}")
    +        .action((x, c) => c.copy(testInput = x))
    +      opt[String]("dataFormat")
    +        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
    +        .action((x, c) => c.copy(dataFormat = x))
    +      arg[String]("<input>")
    +        .text("input path to labeled examples")
    +        .required()
    +        .action((x, c) => c.copy(input = x))
    +      checkConfig { params =>
    +        if (params.fracTest < 0 || params.fracTest > 1) {
    +          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
    +        } else {
    +          success
    +        }
    +      }
    +    }
    +
    +    parser.parse(args, defaultParams).map { params =>
    +      run(params)
    +    }.getOrElse {
    +      sys.exit(1)
    +    }
    +  }
    +
    +  /** Load a dataset from the given path, using the given format */
    +  private[ml] def loadData(
    +      sc: SparkContext,
    +      path: String,
    +      format: String,
    +      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
    +    format match {
    +      case "dense" => MLUtils.loadLabeledPoints(sc, path)
    +      case "libsvm" => expectedNumFeatures match {
    +        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
    +        case None => MLUtils.loadLibSVMFile(sc, path)
    +      }
    +      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
    +    }
    +  }
    +
    +  /**
    +   * Load training and test data from files.
    +   * @param input  Path to input dataset.
    +   * @param dataFormat  "libsvm" or "dense"
    +   * @param testInput  Path to test dataset.
    +   * @param algo  Classification or Regression
    +   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
    +   * @return  (training dataset, test dataset)
    +   */
    +  private[ml] def loadDatasets(
    +      sc: SparkContext,
    +      input: String,
    +      dataFormat: String,
    +      testInput: String,
    +      algo: String,
    +      fracTest: Double): (DataFrame, DataFrame) = {
    +    val sqlContext = new SQLContext(sc)
    +    import sqlContext.implicits._
    +
    +    // Load training data
    +    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
    +
    +    // Load or create test set
    +    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
    +      // Load testInput.
    +      val numFeatures = origExamples.take(1)(0).features.size
    +      val origTestExamples: RDD[LabeledPoint] =
    +        loadData(sc, testInput, dataFormat, Some(numFeatures))
    +      Array(origExamples, origTestExamples)
    +    } else {
    +      // Split input into training, test.
    +      origExamples.randomSplit(Array(1.0 - fracTest, fracTest))
    +    }
    +
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    --- End diff --
    
    Yeah, I like that idea. That way it works for any basic type (as I had wanted), but it doesn't require a fancy implementation. I'll do a follow-up PR for that.
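The diff above converts classification labels to strings (`labelsToStrings`) before indexing them with `StringIndexer`. Here is a plain-Scala sketch of why string labels generalize: once labels of any basic type are rendered as strings, a single indexer handles doubles, strings, and booleans alike. Giving index 0 to the most frequent label mirrors `StringIndexer`'s frequency ordering. The alphabetical tie-break and the `indexLabels` name are assumptions for this sketch.

```scala
// Hypothetical label indexer (plain Scala), not Spark's StringIndexer.
object LabelIndexSketch {
  /**
   * Maps labels of any type to indices via their string form.
   * The most frequent label gets index 0; ties are broken alphabetically (assumption).
   */
  def indexLabels(labels: Seq[Any]): Map[String, Int] = {
    val asStrings = labels.map(_.toString)
    val counts = asStrings.groupBy(identity).map { case (k, v) => (k, v.size) }
    counts.toSeq
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .zipWithIndex
      .toMap
  }
}
```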


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by SparkQA <gi...@git.apache.org>.
Github user SparkQA commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93563602
  
      [Test build #30371 has started](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/30371/consoleFull) for   PR 5530 at commit [`5626c81`](https://github.com/apache/spark/commit/5626c81cc896a04fc109d09ee145487c116ecf6d).


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by SparkQA <gi...@git.apache.org>.
Github user SparkQA commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93853588
  
      [Test build #30439 has started](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/30439/consoleFull) for   PR 5530 at commit [`6aae255`](https://github.com/apache/spark/commit/6aae25587cdcadc0e5d68078ca77d0cdee59e6e4).


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28476280
  
    --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala ---
    @@ -0,0 +1,322 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.examples.ml
    +
    +import scala.collection.mutable
    +import scala.language.reflectiveCalls
    +
    +import scopt.OptionParser
    +
    +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.examples.mllib.AbstractParams
    +import org.apache.spark.ml.{Pipeline, PipelineStage}
    +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
    +import org.apache.spark.ml.impl.tree.DecisionTreeModel
    +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.StringType
    +import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.sql.functions.callUDF
    +
    +
    +/**
    + * An example runner for decision trees and random forests. Run with
    + * {{{
    + * ./bin/run-example ml.DecisionTreeExample [options]
    + * }}}
    + * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    + */
    +object DecisionTreeExample {
    +
    +  case class Params(
    +      input: String = null,
    +      testInput: String = "",
    +      dataFormat: String = "libsvm",
    +      algo: String = "Classification",
    +      maxDepth: Int = 5,
    +      maxBins: Int = 32,
    +      minInstancesPerNode: Int = 1,
    +      minInfoGain: Double = 0.0,
    +      numTrees: Int = 1,
    +      featureSubsetStrategy: String = "auto",
    +      fracTest: Double = 0.2,
    +      cacheNodeIds: Boolean = false,
    +      checkpointDir: Option[String] = None,
    +      checkpointInterval: Int = 10) extends AbstractParams[Params]
    +
    +  def main(args: Array[String]) {
    +    val defaultParams = Params()
    +
    +    val parser = new OptionParser[Params]("DecisionTreeExample") {
    +      head("DecisionTreeExample: an example decision tree app.")
    +      opt[String]("algo")
    +        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
    +        .action((x, c) => c.copy(algo = x))
    +      opt[Int]("maxDepth")
    +        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
    +        .action((x, c) => c.copy(maxDepth = x))
    +      opt[Int]("maxBins")
    +        .text(s"max number of bins, default: ${defaultParams.maxBins}")
    +        .action((x, c) => c.copy(maxBins = x))
    +      opt[Int]("minInstancesPerNode")
    +        .text(s"min number of instances required at child nodes to create the parent split," +
    +          s" default: ${defaultParams.minInstancesPerNode}")
    +        .action((x, c) => c.copy(minInstancesPerNode = x))
    +      opt[Double]("minInfoGain")
    +        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
    +        .action((x, c) => c.copy(minInfoGain = x))
    +      opt[Double]("fracTest")
    +        .text(s"fraction of data to hold out for testing. If option testInput is given, " +
    +          s"this option is ignored. default: ${defaultParams.fracTest}")
    +        .action((x, c) => c.copy(fracTest = x))
    +      opt[Boolean]("cacheNodeIds")
    +        .text(s"whether to use node Id cache during training, " +
    +          s"default: ${defaultParams.cacheNodeIds}")
    +        .action((x, c) => c.copy(cacheNodeIds = x))
    +      opt[String]("checkpointDir")
    +        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
    +         s"default: ${defaultParams.checkpointDir match {
    +           case Some(strVal) => strVal
    +           case None => "None"
    +         }}")
    +        .action((x, c) => c.copy(checkpointDir = Some(x)))
    +      opt[Int]("checkpointInterval")
    +        .text(s"how often to checkpoint the node Id cache, " +
    +         s"default: ${defaultParams.checkpointInterval}")
    +        .action((x, c) => c.copy(checkpointInterval = x))
    +      opt[String]("testInput")
    +        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
    +          s" default: ${defaultParams.testInput}")
    +        .action((x, c) => c.copy(testInput = x))
    +      opt[String]("dataFormat")
    +        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
    +        .action((x, c) => c.copy(dataFormat = x))
    +      arg[String]("<input>")
    +        .text("input path to labeled examples")
    +        .required()
    +        .action((x, c) => c.copy(input = x))
    +      checkConfig { params =>
    +        if (params.fracTest < 0 || params.fracTest > 1) {
    +          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
    +        } else {
    +          success
    +        }
    +      }
    +    }
    +
    +    parser.parse(args, defaultParams).map { params =>
    +      run(params)
    +    }.getOrElse {
    +      sys.exit(1)
    +    }
    +  }
    +
    +  /** Load a dataset from the given path, using the given format */
    +  private[ml] def loadData(
    +      sc: SparkContext,
    +      path: String,
    +      format: String,
    +      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
    +    format match {
    +      case "dense" => MLUtils.loadLabeledPoints(sc, path)
    +      case "libsvm" => expectedNumFeatures match {
    +        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
    +        case None => MLUtils.loadLibSVMFile(sc, path)
    +      }
    +      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
    +    }
    +  }
    +
    +  /**
    +   * Load training and test data from files.
    +   * @param input  Path to input dataset.
    +   * @param dataFormat  "libsvm" or "dense"
    +   * @param testInput  Path to test dataset.
    +   * @param algo  Classification or Regression
    +   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
    +   * @return  (training dataset, test dataset)
    +   */
    +  private[ml] def loadDatasets(
    +      sc: SparkContext,
    +      input: String,
    +      dataFormat: String,
    +      testInput: String,
    +      algo: String,
    +      fracTest: Double): (DataFrame, DataFrame) = {
    +    val sqlContext = new SQLContext(sc)
    +    import sqlContext.implicits._
    +
    +    // Load training data
    +    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
    +
    +    // Load or create test set
    +    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
    +      // Load testInput.
    +      val numFeatures = origExamples.take(1)(0).features.size
    +      val origTestExamples: RDD[LabeledPoint] = loadData(sc, testInput, dataFormat, Some(numFeatures))
    +      Array(origExamples, origTestExamples)
    +    } else {
    +      // Split input into training, test.
    +      origExamples.randomSplit(Array(1.0 - fracTest, fracTest))
    +    }
    +
    +    // For classification, convert labels to Strings since we will index them later with
    +    // StringIndexer.
    +    def labelsToStrings(data: DataFrame): DataFrame = {
    +      algo.toLowerCase match {
    +        case "classification" =>
    +          val convertToString: Double => String = (label: Double) => label.toString
    +          data.select(
    +            callUDF(convertToString, StringType, data("label")).as("labelString"), data("features"))
    --- End diff --
    
    ~~~
    data.withColumn("labelString", data("label").cast(StringType))
    ~~~
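
A local sketch of the cast-then-index flow, assuming the label column is cast to strings (any basic type has a string form) and the strings are then fed to `StringIndexer`, which gives index 0 to the most frequent label. Everything below is plain Scala with hypothetical names; the alphabetical tie-break is an assumption made only to keep the sketch deterministic:

```scala
// Hypothetical local sketch (plain Scala, no Spark) of "cast labels to strings,
// then index them by frequency".
object LabelIndexingSketch {
  // Any basic type has a string form, so this step does not care about the label type.
  def labelsToStrings[T](labels: Seq[T]): Seq[String] = labels.map(_.toString)

  // Most frequent string gets index 0.0, the next gets 1.0, and so on.
  // Ties are broken alphabetically (an assumption, for determinism only).
  def fitIndex(labels: Seq[String]): Map[String, Double] =
    labels.groupBy(identity)
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }
      .zipWithIndex
      .map { case ((label, _), index) => (label, index.toDouble) }
      .toMap
}
```

For labels `1.0, 0.0, 1.0, 2.0, 1.0, 0.0`, this maps `"1.0"` to 0.0, `"0.0"` to 1.0 and `"2.0"` to 2.0.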


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28533775
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/Node.scala ---
    @@ -0,0 +1,201 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.impl.tree
    --- End diff --
    
    I'll move this package to ml.tree, and I'll make as many classes final as possible.
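
Making the node classes final (and the base class sealed) keeps user code from subclassing them, so their internals can change later without breaking anyone, and lets the compiler check pattern matches over node types for exhaustiveness. A simplified sketch under those assumptions; the names and fields are hypothetical, not the actual `ml.tree` classes:

```scala
// Hypothetical, simplified node hierarchy: a sealed base class with final subclasses.
object NodeSketch {
  sealed abstract class Node {
    def predict(features: Array[Double]): Double
  }

  // A leaf just returns its stored prediction.
  final class LeafNode(val prediction: Double) extends Node {
    override def predict(features: Array[Double]): Double = prediction
  }

  // An internal node routes on a threshold over one continuous feature.
  final class InternalNode(
      val featureIndex: Int,
      val threshold: Double,
      val left: Node,
      val right: Node) extends Node {
    override def predict(features: Array[Double]): Double =
      if (features(featureIndex) <= threshold) left.predict(features)
      else right.predict(features)
  }
}
```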


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by AmplabJenkins <gi...@git.apache.org>.
Github user AmplabJenkins commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93864891
  
    Test PASSed.
    Refer to this link for build results (access rights to CI server needed): 
    https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/30439/
    Test PASSed.


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93589652
  
    Unrelated failure; retesting.


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28486987
  
    --- Diff: mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---
    @@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
         assert(bins(0).length === 0)
       }
     
    +  test("Avoid aggregation on the last level") {
    +    val arr = new Array[LabeledPoint](4)
    +    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
    +    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
    +    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
    +    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
    +    val input = sc.parallelize(arr)
    +
    +    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
    +      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
    +    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    val topNode = Node.emptyNode(nodeIndex = 1)
    +    assert(topNode.predict.predict === Double.MinValue)
    +    assert(topNode.impurity === -1.0)
    +    assert(topNode.isLeaf === false)
    +
    +    val nodesForGroup = Map((0, Array(topNode)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
    +    )))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +
    +    // don't enqueue leaf nodes into node queue
    +    assert(nodeQueue.isEmpty)
    +
    +    // set impurity and predict for topNode
    +    assert(topNode.predict.predict !== Double.MinValue)
    +    assert(topNode.impurity !== -1.0)
    +
    +    // set impurity and predict for child nodes
    +    assert(topNode.leftNode.get.predict.predict === 0.0)
    +    assert(topNode.rightNode.get.predict.predict === 1.0)
    +    assert(topNode.leftNode.get.impurity === 0.0)
    +    assert(topNode.rightNode.get.impurity === 0.0)
    +  }
    +
    +  test("Avoid aggregation if impurity is 0.0") {
    +    val arr = new Array[LabeledPoint](4)
    +    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
    +    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
    +    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
    +    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
    +    val input = sc.parallelize(arr)
    +
    +    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
    +      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
    +    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    val topNode = Node.emptyNode(nodeIndex = 1)
    +    assert(topNode.predict.predict === Double.MinValue)
    +    assert(topNode.impurity === -1.0)
    +    assert(topNode.isLeaf === false)
    +
    +    val nodesForGroup = Map((0, Array(topNode)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
    +    )))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +
    +    // don't enqueue a node into node queue if its impurity is 0.0
    +    assert(nodeQueue.isEmpty)
    +
    +    // set impurity and predict for topNode
    +    assert(topNode.predict.predict !== Double.MinValue)
    +    assert(topNode.impurity !== -1.0)
    +
    +    // set impurity and predict for child nodes
    +    assert(topNode.leftNode.get.predict.predict === 0.0)
    +    assert(topNode.rightNode.get.predict.predict === 1.0)
    +    assert(topNode.leftNode.get.impurity === 0.0)
    +    assert(topNode.rightNode.get.impurity === 0.0)
    +  }
    +
    +  test("Second level node building with vs. without groups") {
    +    val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
    +    assert(arr.length === 1000)
    +    val rdd = sc.parallelize(arr)
    +    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
    +    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
    +    assert(splits.length === 2)
    +    assert(splits(0).length === 99)
    +    assert(bins.length === 2)
    +    assert(bins(0).length === 100)
    +
    +    // Train a 1-node model
    +    val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
    +      numClasses = 2, maxBins = 100)
    +    val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
    +    val rootNode1 = modelOneNode.topNode.deepCopy()
    +    val rootNode2 = modelOneNode.topNode.deepCopy()
    +    assert(rootNode1.leftNode.nonEmpty)
    +    assert(rootNode1.rightNode.nonEmpty)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    // Single group second level tree construction.
    +    val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
    +      (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +    val children1 = new Array[Node](2)
    +    children1(0) = rootNode1.leftNode.get
    +    children1(1) = rootNode1.rightNode.get
    +
    +    // Train one second-level node at a time.
    +    val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
    +    val treeToNodeToIndexInfoA = Map((0, Map(
    +      (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
    +    nodeQueue.clear()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
    +      nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
    +    val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
    +    val treeToNodeToIndexInfoB = Map((0, Map(
    +      (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
    +    nodeQueue.clear()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
    +      nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
    +    val children2 = new Array[Node](2)
    +    children2(0) = rootNode2.leftNode.get
    +    children2(1) = rootNode2.rightNode.get
    +
    +    // Verify whether the splits obtained using single group and multiple group level
    +    // construction strategies are the same.
    +    for (i <- 0 until 2) {
    +      assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
    +      assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
    +      assert(children1(i).split === children2(i).split)
    +      assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
    +      val stats1 = children1(i).stats.get
    +      val stats2 = children2(i).stats.get
    +      assert(stats1.gain === stats2.gain)
    +      assert(stats1.impurity === stats2.impurity)
    +      assert(stats1.leftImpurity === stats2.leftImpurity)
    +      assert(stats1.rightImpurity === stats2.rightImpurity)
    +      assert(children1(i).predict.predict === children2(i).predict.predict)
    +    }
    +  }
    --- End diff --
    
    Could you keep the original ordering of the tests? It makes the diff easier to read.
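
The two "Avoid aggregation" tests in the hunk above expect the root to receive a real impurity and both children to end up with impurity 0.0. A quick hand check of that expectation with the same four points, assuming Gini impurity (as the tests configure) and a split on categorical feature 0 that puts categories 1 and 2 on one side and category 0 on the other; this is plain Scala, not MLlib's impurity code:

```scala
// Plain-Scala check of the impurities expected by the "Avoid aggregation" tests.
object GiniSketch {
  // The four (label, features) points used by both tests.
  val points: Seq[(Double, Array[Double])] = Seq(
    (0.0, Array(1.0, 0.0, 0.0)),
    (1.0, Array(0.0, 1.0, 1.0)),
    (0.0, Array(2.0, 0.0, 0.0)),
    (1.0, Array(0.0, 2.0, 1.0)))

  // Gini impurity: 1 - sum over classes k of p_k^2.
  def gini(labels: Seq[Double]): Double =
    if (labels.isEmpty) 0.0
    else {
      val n = labels.size.toDouble
      1.0 - labels.groupBy(identity).values.toSeq.map { group =>
        val p = group.size / n
        p * p
      }.sum
    }

  // Information gain of a binary split: parent impurity minus weighted child impurities.
  def infoGain(left: Seq[Double], right: Seq[Double]): Double = {
    val n = (left.size + right.size).toDouble
    gini(left ++ right) - (left.size / n) * gini(left) - (right.size / n) * gini(right)
  }

  // Split on categorical feature 0: categories 1 and 2 vs. category 0.
  val (leftLabels, rightLabels) = {
    val (l, r) = points.partition { case (_, features) => features(0) != 0.0 }
    (l.map(_._1), r.map(_._1))
  }
}
```

The root's impurity comes out as 0.5, both children come out as 0.0, and the split's gain is 0.5, which matches the tests' expectation that neither child needs further aggregation.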


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on the pull request:

    https://github.com/apache/spark/pull/5530#issuecomment-93853298
  
    @mengxr I believe I've addressed everything. I replied to some of your inline comments, but those replies may be hidden. Thanks for reviewing!


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28537606
  
    --- Diff: mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java ---
    @@ -0,0 +1,97 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification;
    +
    +import java.io.File;
    +import java.io.Serializable;
    +import java.util.HashMap;
    +import java.util.Map;
    +
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.ml.impl.TreeTests;
    +import org.apache.spark.mllib.classification.LogisticRegressionSuite;
    +import org.apache.spark.mllib.regression.LabeledPoint;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.util.Utils;
    +
    +
    +public class JavaDecisionTreeClassifierSuite implements Serializable {
    +
    +  private transient JavaSparkContext sc;
    +
    +  @Before
    +  public void setUp() {
    +    sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    sc.stop();
    +    sc = null;
    +  }
    +
    +  @Test
    +  public void runDT() {
    +    int nPoints = 20;
    +    double A = 2.0;
    +    double B = -1.5;
    +
    +    JavaRDD<LabeledPoint> data = sc.parallelize(
    +        LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
    +    Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
    +    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
    +
    +    DecisionTreeClassifier dt = new DecisionTreeClassifier()
    +      .setMaxDepth(2)
    +      .setMaxBins(10)
    +      .setMinInstancesPerNode(5)
    +      .setMinInfoGain(0.0)
    +      .setMaxMemoryInMB(256)
    +      .setCacheNodeIds(false)
    +      .setCheckpointInterval(10)
    +      .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
    +    for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
    +      dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
    +    }
    +    DecisionTreeClassificationModel model = dt.fit(dataFrame);
    --- End diff --
    
    I was really just testing the setter methods.  I figure the actual impurities should be tested in Scala.  I'll add a note.
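
Since the Java test only exercises the setters, a sketch of the kind of setter contract it checks may help: each setter validates its argument and returns the builder, so calls chain and a repeated call simply overwrites the earlier value. The `TreeBuilder` class below is hypothetical and only loosely modeled on `DecisionTreeClassifier`; the supported impurity names (`entropy`, `gini`) and the lowercase normalization are assumptions:

```scala
// Hypothetical builder, loosely modeled on the Java test above; not Spark's class.
object SetterSketch {
  final class TreeBuilder {
    private var maxDepth: Int = 5
    private var impurity: String = "gini"

    def setMaxDepth(value: Int): this.type = {
      require(value >= 0, s"maxDepth must be >= 0 but was $value")
      maxDepth = value
      this
    }

    // Validate against the supported names; matching is case-insensitive here.
    def setImpurity(value: String): this.type = {
      val normalized = value.toLowerCase
      require(TreeBuilder.supportedImpurities.contains(normalized),
        s"Unsupported impurity: $value")
      impurity = normalized
      this
    }

    def getMaxDepth: Int = maxDepth
    def getImpurity: String = impurity
  }

  object TreeBuilder {
    val supportedImpurities: Array[String] = Array("entropy", "gini")
  }
}
```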


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by jkbradley <gi...@git.apache.org>.
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28555321
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---
    @@ -0,0 +1,155 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
    +import org.apache.spark.ml.impl.tree._
    +import org.apache.spark.ml.param.{Params, ParamMap}
    +import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.DataFrame
    +
    +
    +/**
    + * :: AlphaComponent ::
    + *
    + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
    + * for classification.
    + * It supports both binary and multiclass labels, as well as both continuous and categorical
    + * features.
    + */
    +@AlphaComponent
    +class DecisionTreeClassifier
    +  extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
    +  with DecisionTreeParams[DecisionTreeClassifier]
    +  with TreeClassifierParams[DecisionTreeClassifier] {
    +
    +  // Override parameter setters from parent trait for Java API compatibility.
    +
    +  override def setMaxDepth(maxDepth: Int): DecisionTreeClassifier = super.setMaxDepth(maxDepth)
    --- End diff --
    
    Thanks, "value" was an oversight.
    
    I'll update using this.type.  I checked the Java docs, and I noticed they don't inherit any parameter documentation.  Do you know if there's a way to make that happen?
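
A sketch of the `this.type` pattern, assuming setters defined once in shared parameter traits and re-declared in the concrete class, as the quoted class already does "for Java API compatibility". In Scala alone, the trait setters already chain with the concrete type. Re-declaring them matters for Java callers because, as far as I know, `this.type` is erased to the declaring class or trait in bytecode, so a trait-only setter looks to Java like it returns the trait. All names here are hypothetical:

```scala
// Hypothetical sketch: shared setters return this.type; the concrete class re-declares them.
object ThisTypeSketch {
  trait HasMaxDepth {
    protected var maxDepth: Int = 5
    def setMaxDepth(value: Int): this.type = {
      maxDepth = value
      this
    }
    def getMaxDepth: Int = maxDepth
  }

  trait HasMaxBins {
    protected var maxBins: Int = 32
    def setMaxBins(value: Int): this.type = {
      maxBins = value
      this
    }
    def getMaxBins: Int = maxBins
  }

  final class ClassifierSketch extends HasMaxDepth with HasMaxBins {
    // Re-declared so the erased (Java-visible) return type is ClassifierSketch.
    override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
    override def setMaxBins(value: Int): this.type = super.setMaxBins(value)

    // Only on the concrete class; still reachable after chaining the setters.
    def describe: String = s"depth=$maxDepth, bins=$maxBins"
  }
}
```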


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API

Posted by mengxr <gi...@git.apache.org>.
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28626488
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---
    @@ -0,0 +1,155 @@
    +  // Override parameter setters from parent trait for Java API compatibility.
    +
    +  override def setMaxDepth(maxDepth: Int): DecisionTreeClassifier = super.setMaxDepth(maxDepth)
    --- End diff --
    
    I don't know ...

