You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/01/07 02:53:21 UTC

spark git commit: [SPARK-18194][ML] Log instrumentation in OneVsRest, CrossValidator, TrainValidationSplit

Repository: spark
Updated Branches:
  refs/heads/master b59cddaba -> d60f6f62d


[SPARK-18194][ML] Log instrumentation in OneVsRest, CrossValidator, TrainValidationSplit

## What changes were proposed in this pull request?

Added instrumentation logging to the fit() methods of the OneVsRest classifier, CrossValidator, and TrainValidationSplit.

## How was this patch tested?

Ran unit tests and checked the log file (see output in comments).

Author: sueann <su...@databricks.com>

Closes #16480 from sueann/SPARK-18194.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d60f6f62
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d60f6f62
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d60f6f62

Branch: refs/heads/master
Commit: d60f6f62d00ffccc40ed72e15349358fe3543311
Parents: b59cdda
Author: sueann <su...@databricks.com>
Authored: Fri Jan 6 18:53:16 2017 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Jan 6 18:53:16 2017 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/classification/OneVsRest.scala   |  7 +++++++
 .../scala/org/apache/spark/ml/recommendation/ALS.scala   |  6 +++---
 .../org/apache/spark/ml/tuning/CrossValidator.scala      |  6 ++++++
 .../apache/spark/ml/tuning/TrainValidationSplit.scala    |  5 +++++
 .../org/apache/spark/ml/tuning/ValidatorParams.scala     | 11 ++++++++++-
 .../scala/org/apache/spark/ml/util/Instrumentation.scala |  8 ++++++--
 6 files changed, 37 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d60f6f62/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index e58b30d..cbd508a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -308,6 +308,10 @@ final class OneVsRest @Since("1.4.0") (
   override def fit(dataset: Dataset[_]): OneVsRestModel = {
     transformSchema(dataset.schema)
 
+    val instr = Instrumentation.create(this, dataset)
+    instr.logParams(labelCol, featuresCol, predictionCol)
+    instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
+
     // determine number of classes either from metadata if provided, or via computation.
     val labelSchema = dataset.schema($(labelCol))
     val computeNumClasses: () => Int = () => {
@@ -316,6 +320,7 @@ final class OneVsRest @Since("1.4.0") (
       maxLabelIndex.toInt + 1
     }
     val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
+    instr.logNumClasses(numClasses)
 
     val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
 
@@ -339,6 +344,7 @@ final class OneVsRest @Since("1.4.0") (
       paramMap.put(classifier.predictionCol -> getPredictionCol)
       classifier.fit(trainingDataset, paramMap)
     }.toArray[ClassificationModel[_, _]]
+    instr.logNumFeatures(models.head.numFeatures)
 
     if (handlePersistence) {
       multiclassLabeled.unpersist()
@@ -352,6 +358,7 @@ final class OneVsRest @Since("1.4.0") (
       case attr: Attribute => attr
     }
     val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
+    instr.logSuccess(model)
     copyValues(model)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d60f6f62/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index b466e2e..cdea90e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -457,8 +457,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
-    val instrLog = Instrumentation.create(this, ratings)
-    instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
+    val instr = Instrumentation.create(this, ratings)
+    instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
                        userCol, itemCol, ratingCol, predictionCol, maxIter,
                        regParam, nonnegative, checkpointInterval, seed)
     val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
@@ -471,7 +471,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
     val userDF = userFactors.toDF("id", "features")
     val itemDF = itemFactors.toDF("id", "features")
     val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
-    instrLog.logSuccess(model)
+    instr.logSuccess(model)
     copyValues(model)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d60f6f62/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 85191d4..2012d6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -101,6 +101,11 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val epm = $(estimatorParamMaps)
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
+
+    val instr = Instrumentation.create(this, dataset)
+    instr.logParams(numFolds, seed)
+    logTuningParams(instr)
+
     val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
       val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
@@ -127,6 +132,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best cross-validation metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+    instr.logSuccess(bestModel)
     copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d60f6f62/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 5d1a39f..db7c9d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -97,6 +97,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
 
+    val instr = Instrumentation.create(this, dataset)
+    instr.logParams(trainRatio, seed)
+    logTuningParams(instr)
+
     val Array(trainingDataset, validationDataset) =
       dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
     trainingDataset.cache()
@@ -123,6 +127,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best train validation split metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+    instr.logSuccess(bestModel)
     copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d60f6f62/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 26fd738..d55eb14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.HasSeed
-import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, MLWritable}
+import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.sql.types.StructType
 
@@ -76,6 +76,15 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
     }
     est.copy(firstEstimatorParamMap).transformSchema(schema)
   }
+
+  /**
+   * Instrumentation logging for tuning params including the inner estimator and evaluator info.
+   */
+  protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
+    instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
+    instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
+    instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
+  }
 }
 
 private[ml] object ValidatorParams {

http://git-wip-us.apache.org/repos/asf/spark/blob/d60f6f62/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 71a6266..a279436 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -87,8 +87,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
   /**
    * Logs the value with customized name field.
    */
-  def logNamedValue(name: String, num: Long): Unit = {
-    log(compact(render(name -> num)))
+  def logNamedValue(name: String, value: String): Unit = {
+    log(compact(render(name -> value)))
+  }
+
+  def logNamedValue(name: String, value: Long): Unit = {
+    log(compact(render(name -> value)))
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org