You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2018/05/17 18:13:19 UTC

spark git commit: [SPARK-24115] Have logging pass through instrumentation class.

Repository: spark
Updated Branches:
  refs/heads/master 8a837bf4f -> a7a9b1837


[SPARK-24115] Have logging pass through instrumentation class.

## What changes were proposed in this pull request?

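Fixes to tuning instrumentation: `CrossValidator` and `TrainValidationSplit` now emit their log messages through the `Instrumentation` instance (`instr.logDebug` / `instr.logInfo`) instead of calling `logDebug` / `logInfo` directly, and `Instrumentation` gains a `logDebug` override that prepends the training-session prefix to each message.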

## How was this patch tested?

Existing tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Bago Amirbekian <ba...@databricks.com>

Closes #21340 from MrBago/tunning-instrumentation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a7a9b183
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a7a9b183
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a7a9b183

Branch: refs/heads/master
Commit: a7a9b1837808b281f47643490abcf054f6de7b50
Parents: 8a837bf
Author: Bago Amirbekian <ba...@databricks.com>
Authored: Thu May 17 11:13:16 2018 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu May 17 11:13:16 2018 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 10 +++++-----
 .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 10 +++++-----
 .../scala/org/apache/spark/ml/util/Instrumentation.scala  |  7 +++++++
 3 files changed, 17 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a7a9b183/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5e916cc..f327f37 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -144,7 +144,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
       val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
       val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
-      logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+      instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")
 
       // Fit models in a Future for training in parallel
       val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
@@ -155,7 +155,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
           }
           // TODO: duplicate evaluator to take extra params from input
           val metric = eval.evaluate(model.transform(validationDataset, paramMap))
-          logDebug(s"Got metric $metric for model trained with $paramMap.")
+          instr.logDebug(s"Got metric $metric for model trained with $paramMap.")
           metric
         } (executionContext)
       }
@@ -169,12 +169,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
       foldMetrics
     }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits
 
-    logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
+    instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
     val (bestMetric, bestIndex) =
       if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
       else metrics.zipWithIndex.minBy(_._1)
-    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
-    logInfo(s"Best cross-validation metric: $bestMetric.")
+    instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+    instr.logInfo(s"Best cross-validation metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
     instr.logSuccess(bestModel)
     copyValues(new CrossValidatorModel(uid, bestModel, metrics)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a9b183/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 13369c4..14d6a69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -143,7 +143,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     } else None
 
     // Fit models in a Future for training in parallel
-    logDebug(s"Train split with multiple sets of parameters.")
+    instr.logDebug(s"Train split with multiple sets of parameters.")
     val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
       Future[Double] {
         val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
@@ -153,7 +153,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
         }
         // TODO: duplicate evaluator to take extra params from input
         val metric = eval.evaluate(model.transform(validationDataset, paramMap))
-        logDebug(s"Got metric $metric for model trained with $paramMap.")
+        instr.logDebug(s"Got metric $metric for model trained with $paramMap.")
         metric
       } (executionContext)
     }
@@ -165,12 +165,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     trainingDataset.unpersist()
     validationDataset.unpersist()
 
-    logInfo(s"Train validation split metrics: ${metrics.toSeq}")
+    instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}")
     val (bestMetric, bestIndex) =
       if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
       else metrics.zipWithIndex.minBy(_._1)
-    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
-    logInfo(s"Best train validation split metric: $bestMetric.")
+    instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+    instr.logInfo(s"Best train validation split metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
     instr.logSuccess(bestModel)
     copyValues(new TrainValidationSplitModel(uid, bestModel, metrics)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a9b183/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 3247c39..467130b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -59,6 +59,13 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
   }
 
   /**
+   * Logs a debug message with a prefix that uniquely identifies the training session.
+   */
+  override def logDebug(msg: => String): Unit = {
+    super.logDebug(prefix + msg)
+  }
+
+  /**
    * Logs a warning message with a prefix that uniquely identifies the training session.
    */
   override def logWarning(msg: => String): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org