Posted to commits@spark.apache.org by jk...@apache.org on 2018/07/17 20:12:18 UTC

spark git commit: [SPARK-24747][ML] Make Instrumentation class more flexible

Repository: spark
Updated Branches:
  refs/heads/master 7688ce88b -> 912634b00


[SPARK-24747][ML] Make Instrumentation class more flexible

## What changes were proposed in this pull request?

This PR updates the Instrumentation class to make it more flexible and a little easier to use. Once these APIs are merged, I'll follow up with a PR that updates the training code to use the new APIs so the old ones can be removed. All of these changes are to private APIs, so this PR doesn't make any user-facing changes.
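
For context, a rough sketch of the new pattern (not part of this commit; `MyModel` and `fitImpl` are hypothetical, and these helpers are private[spark], so this only compiles inside the Spark tree):

    import org.apache.spark.ml.util.Instrumentation.instrumented
    import org.apache.spark.sql.Dataset

    protected def train(dataset: Dataset[_]): MyModel = instrumented { instr =>
      instr.logPipelineStage(this)          // logs the stage's class name and uid
      instr.logDataset(dataset)             // logs numPartitions and storage level
      instr.logParams(this, maxIter, tol)   // logs the given params as JSON
      val model = fitImpl(dataset)          // hypothetical training routine
      model                                 // instrumented() itself logs success or failure
    }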

## How was this patch tested?

Existing tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Bago Amirbekian <ba...@databricks.com>

Closes #21719 from MrBago/new-instrumentation-apis.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/912634b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/912634b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/912634b0

Branch: refs/heads/master
Commit: 912634b004c2302533a8a8501b4ecb803d17e335
Parents: 7688ce8
Author: Bago Amirbekian <ba...@databricks.com>
Authored: Tue Jul 17 13:11:52 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Jul 17 13:11:52 2018 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  |   8 +-
 .../spark/ml/tree/impl/RandomForest.scala       |   2 +-
 .../spark/ml/tuning/ValidatorParams.scala       |   2 +-
 .../apache/spark/ml/util/Instrumentation.scala  | 128 ++++++++++++-------
 .../apache/spark/mllib/clustering/KMeans.scala  |   4 +-
 5 files changed, 93 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/912634b0/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 92e342e..25fb9c8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -35,6 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
@@ -490,7 +491,7 @@ class LogisticRegression @Since("1.2.0") (
 
   protected[spark] def train(
       dataset: Dataset[_],
-      handlePersistence: Boolean): LogisticRegressionModel = {
+      handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
@@ -500,7 +501,8 @@ class LogisticRegression @Since("1.2.0") (
 
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    val instr = Instrumentation.create(this, dataset)
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
     instr.logParams(regParam, elasticNetParam, standardization, threshold,
       maxIter, tol, fitIntercept)
 
@@ -905,8 +907,6 @@ class LogisticRegression @Since("1.2.0") (
         objectiveHistory)
     }
     model.setSummary(Some(logRegSummary))
-    instr.logSuccess(model)
-    model
   }
 
   @Since("1.4.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/912634b0/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 9058701..bb3f3a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -91,7 +91,7 @@ private[spark] object RandomForest extends Logging {
       numTrees: Int,
       featureSubsetStrategy: String,
       seed: Long,
-      instr: Option[Instrumentation[_]],
+      instr: Option[Instrumentation],
       prune: Boolean = true, // exposed for testing only, real trees are always pruned
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/912634b0/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 363304e..1358288 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -80,7 +80,7 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
   /**
    * Instrumentation logging for tuning params including the inner estimator and evaluator info.
    */
-  protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
+  protected def logTuningParams(instrumentation: Instrumentation): Unit = {
     instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
     instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
     instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)

http://git-wip-us.apache.org/repos/asf/spark/blob/912634b0/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 11f46eb..2e43a9e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -19,15 +19,16 @@ package org.apache.spark.ml.util
 
 import java.util.UUID
 
-import scala.reflect.ClassTag
+import scala.util.{Failure, Success, Try}
+import scala.util.control.NonFatal
 
 import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.util.Utils
@@ -35,29 +36,44 @@ import org.apache.spark.util.Utils
 /**
  * A small wrapper that defines a training session for an estimator, and some methods to log
  * useful information during this session.
- *
- * A new instance is expected to be created within fit().
- *
- * @param estimator the estimator that is being fit
- * @param dataset the training dataset
- * @tparam E the type of the estimator
  */
-private[spark] class Instrumentation[E <: Estimator[_]] private (
-    val estimator: E,
-    val dataset: RDD[_]) extends Logging {
+private[spark] class Instrumentation extends Logging {
 
   private val id = UUID.randomUUID()
-  private val prefix = {
+  private val shortId = id.toString.take(8)
+  private val prefix = s"[$shortId] "
+
+  // TODO: remove stage
+  var stage: Params = _
+  // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor
+  private def this(estimator: Estimator[_], dataset: RDD[_]) = {
+    this()
+    logPipelineStage(estimator)
+    logDataset(dataset)
+  }
+
+  /**
+   * Log some info about the pipeline stage being fit.
+   */
+  def logPipelineStage(stage: PipelineStage): Unit = {
+    this.stage = stage
     // estimator.getClass.getSimpleName can cause Malformed class name error,
     // call safer `Utils.getSimpleName` instead
-    val className = Utils.getSimpleName(estimator.getClass)
-    s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
+    val className = Utils.getSimpleName(stage.getClass)
+    logInfo(s"Stage class: $className")
+    logInfo(s"Stage uid: ${stage.uid}")
   }
 
-  init()
+  /**
+   * Log some data about the dataset being fit.
+   */
+  def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd)
 
-  private def init(): Unit = {
-    log(s"training: numPartitions=${dataset.partitions.length}" +
+  /**
+   * Log some data about the dataset being fit.
+   */
+  def logDataset(dataset: RDD[_]): Unit = {
+    logInfo(s"training: numPartitions=${dataset.partitions.length}" +
       s" storageLevel=${dataset.getStorageLevel}")
   }
 
@@ -90,22 +106,24 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
   }
 
   /**
-   * Alias for logInfo, see above.
-   */
-  def log(msg: String): Unit = logInfo(msg)
-
-  /**
    * Logs the value of the given parameters for the estimator being used in this session.
    */
-  def logParams(params: Param[_]*): Unit = {
+  def logParams(hasParams: Params, params: Param[_]*): Unit = {
     val pairs: Seq[(String, JValue)] = for {
       p <- params
-      value <- estimator.get(p)
+      value <- hasParams.get(p)
     } yield {
       val cast = p.asInstanceOf[Param[Any]]
       p.name -> parse(cast.jsonEncode(value))
     }
-    log(compact(render(map2jvalue(pairs.toMap))))
+    logInfo(compact(render(map2jvalue(pairs.toMap))))
+  }
+
+  // TODO: remove this
+  def logParams(params: Param[_]*): Unit = {
+    require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" +
+      " Params must be provided explicitly).")
+    logParams(stage, params: _*)
   }
 
   def logNumFeatures(num: Long): Unit = {
@@ -124,35 +142,48 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
    * Logs the value with customized name field.
    */
   def logNamedValue(name: String, value: String): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Long): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Double): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Array[String]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
   def logNamedValue(name: String, value: Array[Long]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
   def logNamedValue(name: String, value: Array[Double]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
 
+  // TODO: Remove this (possibly replace with logModel?)
   /**
    * Logs the successful completion of the training session.
    */
   def logSuccess(model: Model[_]): Unit = {
-    log(s"training finished")
+    logInfo(s"training finished")
+  }
+
+  def logSuccess(): Unit = {
+    logInfo("training finished")
+  }
+
+  /**
+   * Logs an exception raised during a training session.
+   */
+  def logFailure(e: Throwable): Unit = {
+    val msg = e.getStackTrace.mkString("\n")
+    super.logError(msg)
   }
 }
 
@@ -169,22 +200,33 @@ private[spark] object Instrumentation {
     val varianceOfLabels = "varianceOfLabels"
   }
 
+  // TODO: Remove these
   /**
    * Creates an instrumentation object for a training session.
    */
-  def create[E <: Estimator[_]](
-      estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
-    create[E](estimator, dataset.rdd)
+  def create(estimator: Estimator[_], dataset: Dataset[_]): Instrumentation = {
+    create(estimator, dataset.rdd)
   }
 
   /**
    * Creates an instrumentation object for a training session.
    */
-  def create[E <: Estimator[_]](
-      estimator: E, dataset: RDD[_]): Instrumentation[E] = {
-    new Instrumentation[E](estimator, dataset)
+  def create(estimator: Estimator[_], dataset: RDD[_]): Instrumentation = {
+    new Instrumentation(estimator, dataset)
+  }
+  // end remove
+
+  def instrumented[T](body: (Instrumentation => T)): T = {
+    val instr = new Instrumentation()
+    Try(body(instr)) match {
+      case Failure(NonFatal(e)) =>
+        instr.logFailure(e)
+        throw e
+      case Success(result) =>
+        instr.logSuccess()
+        result
+    }
   }
-
 }
 
 /**
@@ -193,7 +235,7 @@ private[spark] object Instrumentation {
  * will log via it, otherwise will log via common logger.
  */
 private[spark] class OptionalInstrumentation private(
-    val instrumentation: Option[Instrumentation[_ <: Estimator[_]]],
+    val instrumentation: Option[Instrumentation],
     val className: String) extends Logging {
 
   protected override def logName: String = className
@@ -225,9 +267,9 @@ private[spark] object OptionalInstrumentation {
   /**
    * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object.
    */
-  def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = {
+  def create(instr: Instrumentation): OptionalInstrumentation = {
     new OptionalInstrumentation(Some(instr),
-      instr.estimator.getClass.getName.stripSuffix("$"))
+      instr.stage.getClass.getName.stripSuffix("$"))
   }
 
   /**
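
As a hedged illustration (hypothetical call site; the API is private[spark]), the control flow that the new `instrumented` helper provides looks like this:

    import org.apache.spark.ml.util.Instrumentation.instrumented

    val numExamples: Long = instrumented { instr =>
      instr.logNamedValue("numExamples", 100L)  // logged as {"numExamples":100}
      100L                                      // on success, instr.logSuccess() runs, then the value is returned
    }
    // If the body throws a non-fatal exception, instrumented() calls
    // instr.logFailure(e), which logs the stack trace, and then rethrows.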

http://git-wip-us.apache.org/repos/asf/spark/blob/912634b0/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 37ae8b1..4f554f4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -235,7 +235,7 @@ class KMeans private (
 
   private[spark] def run(
       data: RDD[Vector],
-      instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
+      instr: Option[Instrumentation]): KMeansModel = {
 
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
@@ -264,7 +264,7 @@ class KMeans private (
    */
   private def runAlgorithm(
       data: RDD[VectorWithNorm],
-      instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
+      instr: Option[Instrumentation]): KMeansModel = {
 
     val sc = data.sparkContext
 

