You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2018/05/09 04:21:02 UTC
spark git commit: [SPARK-24132][ML] Instrumentation improvement for classification
Repository: spark
Updated Branches:
refs/heads/master 9498e528d -> 7e7350285
[SPARK-24132][ML] Instrumentation improvement for classification
## What changes were proposed in this pull request?
- Add OptionalInstrumentation as argument for getNumClasses in ml.classification.Classifier
- Change the function call for getNumClasses in train() in ml.classification.DecisionTreeClassifier, ml.classification.RandomForestClassifier, and ml.classification.NaiveBayes
- Modify the instrumentation creation in ml.classification.LinearSVC
- Change the log call in ml.classification.OneVsRest and ml.classification.LinearSVC
## How was this patch tested?
Manual.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Lu WANG <lu...@databricks.com>
Closes #21204 from ludatabricks/SPARK-23686.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e735028
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e735028
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e735028
Branch: refs/heads/master
Commit: 7e7350285dc22764f599671d874617c0eea093e5
Parents: 9498e52
Author: Lu WANG <lu...@databricks.com>
Authored: Tue May 8 21:20:58 2018 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 8 21:20:58 2018 -0700
----------------------------------------------------------------------
.../spark/ml/classification/DecisionTreeClassifier.scala | 9 ++++++---
.../org/apache/spark/ml/classification/LinearSVC.scala | 9 ++++++---
.../org/apache/spark/ml/classification/NaiveBayes.scala | 3 ++-
.../org/apache/spark/ml/classification/OneVsRest.scala | 4 ++--
.../spark/ml/classification/RandomForestClassifier.scala | 4 +++-
5 files changed, 19 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 57797d1..c9786f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") (
override def setSeed(value: Long): this.type = set(seed, value)
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
+ val instr = Instrumentation.create(this, dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
+ instr.logNumClasses(numClasses)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
- val instr = Instrumentation.create(this, oldDataset)
- instr.logParams(params: _*)
+ instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+ cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
@@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
val instr = Instrumentation.create(this, data)
- instr.logParams(params: _*)
+ instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+ cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, instr = Some(instr), parentUID = Some(uid))
http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 80c537e..38eb045 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") (
Instance(label, weight, features)
}
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold,
aggregationDepth)
@@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") (
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
)(seqOp, combOp, $(aggregationDepth))
}
+ instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
+ instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
+ instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
@@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") (
if (numInvalid != 0) {
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
s"Found $numInvalid invalid labels."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
@@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") (
bcFeaturesStd.destroy(blocking = false)
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 45fb585..1dde18d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") (
private[spark] def trainWithLabelCheck(
dataset: Dataset[_],
positiveLabel: Boolean): NaiveBayesModel = {
+ val instr = Instrumentation.create(this, dataset)
if (positiveLabel && isDefined(thresholds)) {
val numClasses = getNumClasses(dataset)
+ instr.logNumClasses(numClasses)
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
@@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") (
}
}
- val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
probabilityCol, modelType, smoothing, thresholds)
http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7df53a6..3474b61 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") (
transformSchema(dataset.schema)
val instr = Instrumentation.create(this, dataset)
- instr.logParams(labelCol, featuresCol, predictionCol, parallelism)
+ instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol)
instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
// determine number of classes either from metadata if provided, or via computation.
@@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") (
getClassifier match {
case _: HasWeightCol => true
case c =>
- logWarning(s"weightCol is ignored, as it is not supported by $c now.")
+ instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index f1ef26a..040db3b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") (
set(featureSubsetStrategy, value)
override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
+ val instr = Instrumentation.create(this, dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
@@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") (
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
- val instr = Instrumentation.create(this, oldDataset)
instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
@@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
instr.logSuccess(m)
m
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org