You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/05/12 22:35:28 UTC

spark git commit: [SPARK-7015] [MLLIB] [WIP] Multiclass to Binary Reduction: One Against All

Repository: spark
Updated Branches:
  refs/heads/master 5438f49cc -> 595a67589


[SPARK-7015] [MLLIB] [WIP] Multiclass to Binary Reduction: One Against All

Initial cut of one-against-all (one-vs-rest) reduction. The test code is scaffolding and is not yet fully implemented.
This WIP is intended to gather early feedback.

Author: Ram Sriharsha <rs...@hw11853.local>

Closes #5830 from harsha2010/reduction and squashes the following commits:

5f4b495 [Ram Sriharsha] Fix Test
386e98b [Ram Sriharsha] Style fix
49b4a17 [Ram Sriharsha] Simplify the test
02279cc [Ram Sriharsha] Output Label Metadata in Prediction Col
bc78032 [Ram Sriharsha] Code Review Updates
8ce4845 [Ram Sriharsha] Merge with Master
2a807be [Ram Sriharsha] Merge branch 'master' into reduction
e21bfcc [Ram Sriharsha] Style Fix
5614f23 [Ram Sriharsha] Style Fix
c75583a [Ram Sriharsha] Cleanup
7a5f136 [Ram Sriharsha] Fix TODOs
804826b [Ram Sriharsha] Merge with Master
1448a5f [Ram Sriharsha] Style Fix
6e47807 [Ram Sriharsha] Style Fix
d63e46b [Ram Sriharsha] Incorporate Code Review Feedback
ced68b5 [Ram Sriharsha] Refactor OneVsAll to implement Predictor
78fa82a [Ram Sriharsha] extra line
0dfa1fb [Ram Sriharsha] Fix inexhaustive match cases that may arise from UnresolvedAttribute
a59a4f4 [Ram Sriharsha] @Experimental
4167234 [Ram Sriharsha] Merge branch 'master' into reduction
868a4fd [Ram Sriharsha] @Experimental
041d905 [Ram Sriharsha] Code Review Fixes
df188d8 [Ram Sriharsha] Style fix
612ec48 [Ram Sriharsha] Style Fix
6ef43d3 [Ram Sriharsha] Prefer Unresolved Attribute to Option: Java APIs are cleaner
6bf6bff [Ram Sriharsha] Update OneHotEncoder to new API
e29cb89 [Ram Sriharsha] Merge branch 'master' into reduction
1c7fa44 [Ram Sriharsha] Fix Tests
ca83672 [Ram Sriharsha] Incorporate Code Review Feedback + Rename to OneVsRestClassifier
221beeed [Ram Sriharsha] Upgrade to use Copy method for cloning Base Classifiers
26f1ddb [Ram Sriharsha] Merge with SPARK-5956 API changes
9738744 [Ram Sriharsha] Merge branch 'master' into reduction
1a3e375 [Ram Sriharsha] More efficient Implementation: Use withColumn to generate label column dynamically
32e0189 [Ram Sriharsha] Restrict reduction to Margin Based Classifiers
ff272da [Ram Sriharsha] Style fix
28771f5 [Ram Sriharsha] Add Tests for Multiclass to Binary Reduction
b60f874 [Ram Sriharsha] Fix Style issues in Test
3191cdf [Ram Sriharsha] Remove this test, accidental commit
23f056c [Ram Sriharsha] Fix Headers for test
1b5e929 [Ram Sriharsha] Fix Style issues and add Header
8752863 [Ram Sriharsha] [SPARK-7015][MLLib][WIP] Multiclass to Binary Reduction: One Against All


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/595a6758
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/595a6758
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/595a6758

Branch: refs/heads/master
Commit: 595a67589a42f8025d3e5fd4da413b1faa2e14bf
Parents: 5438f49
Author: Ram Sriharsha <rs...@hw11853.local>
Authored: Tue May 12 13:35:12 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue May 12 13:35:12 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Predictor.scala   |   3 +-
 .../spark/ml/attribute/AttributeGroup.scala     |   1 +
 .../spark/ml/attribute/AttributeType.scala      |   8 +
 .../apache/spark/ml/attribute/attributes.scala  |  37 +++-
 .../apache/spark/ml/feature/VectorIndexer.scala |   4 +-
 .../apache/spark/ml/reduction/OneVsRest.scala   | 211 +++++++++++++++++++
 .../apache/spark/ml/util/MetadataUtils.scala    |   7 +-
 .../spark/ml/reduction/JavaOneVsRestSuite.java  |  85 ++++++++
 .../spark/ml/attribute/AttributeSuite.scala     |  10 +-
 .../spark/ml/reduction/OneVsRestSuite.scala     | 113 ++++++++++
 10 files changed, 471 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 0e53877..f6a5f27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -113,7 +113,8 @@ abstract class Predictor[
    *
    * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
    */
-  protected def featuresDataType: DataType = new VectorUDT
+  @DeveloperApi
+  private[ml] def featuresDataType: DataType = new VectorUDT
 
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, featuresDataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index d7dee8f..f5f37aa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -123,6 +123,7 @@ class AttributeGroup private (
           nominalMetadata += nominal.toMetadataImpl(withType = false)
         case binary: BinaryAttribute =>
           binaryMetadata += binary.toMetadataImpl(withType = false)
+        case UnresolvedAttribute =>
       }
       val attrBldr = new MetadataBuilder
       if (numericMetadata.nonEmpty) {

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
index 65e7e43..a83febd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
@@ -43,6 +43,12 @@ object AttributeType {
     Binary
   }
 
+  /** Unresolved type. */
+  val Unresolved: AttributeType = {
+    case object Unresolved extends AttributeType("unresolved")
+    Unresolved
+  }
+
   /**
    * Gets the [[AttributeType]] object from its name.
    * @param name attribute type name: "numeric", "nominal", or "binary"
@@ -54,6 +60,8 @@ object AttributeType {
       Nominal
     } else if (name == Binary.name) {
       Binary
+    } else if (name == Unresolved.name) {
+      Unresolved
     } else {
       throw new IllegalArgumentException(s"Cannot recognize type $name.")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 5717d6e..e8f7f15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -125,7 +125,13 @@ private[attribute] trait AttributeFactory {
    */
   def fromStructField(field: StructField): Attribute = {
     require(field.dataType == DoubleType)
-    fromMetadata(field.metadata.getMetadata(AttributeKeys.ML_ATTR)).withName(field.name)
+    val metadata = field.metadata
+    val mlAttr = AttributeKeys.ML_ATTR
+    if (metadata.contains(mlAttr)) {
+      fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name)
+    } else {
+      UnresolvedAttribute
+    }
   }
 }
 
@@ -535,3 +541,32 @@ object BinaryAttribute extends AttributeFactory {
     new BinaryAttribute(name, index, values)
   }
 }
+
+/**
+ * An unresolved attribute.
+ */
+object UnresolvedAttribute extends Attribute {
+
+  override def attrType: AttributeType = AttributeType.Unresolved
+
+  override def withIndex(index: Int): Attribute = this
+
+  override def isNumeric: Boolean = false
+
+  override def withoutIndex: Attribute = this
+
+  override def isNominal: Boolean = false
+
+  override def name: Option[String] = None
+
+  override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
+    Metadata.empty
+  }
+
+  override def withoutName: Attribute = this
+
+  override def index: Option[Int] = None
+
+  override def withName(name: String): Attribute = this
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 07ea579..2e6313a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
@@ -375,6 +375,8 @@ class VectorIndexerModel private[ml] (
           }
         case (origAttr: Attribute, featAttr: NumericAttribute) =>
           origAttr.withIndex(featAttr.index.get)
+        case (origAttr: Attribute, _) =>
+          origAttr
       }
     } else {
       partialFeatureAttributes

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
new file mode 100644
index 0000000..0a6728e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction
+
+import java.util.UUID
+
+import scala.language.existentials
+
+import org.apache.spark.annotation.{AlphaComponent, Experimental}
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
+import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams {
+
+  type ClassifierType = Classifier[F, E, M] forSome {
+    type F
+    type M <: ClassificationModel[F, M]
+    type E <:  Classifier[F, E, M]
+  }
+
+  /**
+   * param for the base binary classifier that we reduce multiclass classification into.
+   * @group param
+   */
+  val classifier: Param[ClassifierType]  =
+    new Param(this, "classifier", "base binary classifier ")
+
+  /** @group getParam */
+  def getClassifier: ClassifierType = $(classifier)
+
+}
+
+/**
+ * Model produced by [[OneVsRest]].
+ * Stores the models resulting from training k different classifiers:
+ * one for each class.
+ * Each example is scored against all k models and the model with highest score
+ * is picked to label the example.
+ * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
+ * @param parent the [[OneVsRest]] estimator that produced this model
+ * @param labelMetadata Metadata of label column if it exists, or Nominal attribute
+ *                      representing the number of classes in training dataset otherwise.
+ * @param models the binary classification models for reduction.
+ *               The i-th model is produced by training the i-th class vs the rest.
+ */
+@AlphaComponent
+class OneVsRestModel(
+      override val parent: OneVsRest,
+      labelMetadata: Metadata,
+      val models: Array[_ <: ClassificationModel[_,_]])
+  extends Model[OneVsRestModel] with OneVsRestParams {
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    // Check schema
+    transformSchema(dataset.schema, logging = true)
+
+    // determine the input columns: these need to be passed through
+    val origCols = dataset.schema.map(f => col(f.name))
+
+    // add an accumulator column to store predictions of all the models
+    val accColName = "mbc$acc" + UUID.randomUUID().toString
+    val init: () => Map[Int, Double] = () => {Map()}
+    val mapType = MapType(IntegerType, DoubleType, false)
+    val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
+
+    // persist if underlying dataset is not persistent.
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) {
+      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // update the accumulator column with the result of prediction of models
+    val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
+      case (df, (model, index)) => {
+        val rawPredictionCol = model.getRawPredictionCol
+        val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
+
+        // add temporary column to store intermediate scores and update
+        val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
+        val update: (Map[Int, Double], Vector) => Map[Int, Double]  =
+          (predictions: Map[Int, Double], prediction: Vector) => {
+            predictions + ((index, prediction(1)))
+        }
+        val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
+        val transformedDataset = model.transform(df).select(columns:_*)
+        val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
+        val newColumns = origCols ++ List(col(tmpColName))
+
+        // switch out the intermediate column with the accumulator column
+        updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
+      }
+    }
+
+    if (handlePersistence) {
+      newDataset.unpersist()
+    }
+
+    // output the index of the classifier with highest confidence as prediction
+    val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
+      predictions.maxBy(_._2)._1.toDouble
+    }
+
+    // output label and label metadata as prediction
+    val labelUdf = callUDF(label, DoubleType, col(accColName))
+    aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+  }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Reduction of Multiclass Classification to Binary Classification.
+ * Performs reduction using one against all strategy.
+ * For a multiclass classification with k classes, train k models (one per class).
+ * Each example is scored against all k models and the model with highest score
+ * is picked to label the example.
+ */
+@Experimental
+final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
+
+  /** @group setParam */
+  // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
+  def setClassifier(value: Classifier[_,_,_]): this.type = {
+    set(classifier, value.asInstanceOf[ClassifierType])
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
+  }
+
+  override def fit(dataset: DataFrame): OneVsRestModel = {
+    // determine number of classes either from metadata if provided, or via computation.
+    val labelSchema = dataset.schema($(labelCol))
+    val computeNumClasses: () => Int = () => {
+      val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+      // classes are assumed to be numbered from 0,...,maxLabelIndex
+      maxLabelIndex.toInt + 1
+    }
+    val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
+
+    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+
+    // persist if underlying dataset is not persistent.
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) {
+      multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // train k binary classifiers, one per class, each on its own relabeled label column.
+    val models = Range(0, numClasses).par.map { index =>
+
+      val label: Double => Double = (label: Double) => {
+        if (label.toInt == index) 1.0 else 0.0
+      }
+
+      // generate new label metadata for the binary problem.
+      // TODO: use when ... otherwise after SPARK-7321 is merged
+      val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
+      val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
+      val labelColName = "mc2b$" + index
+      val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
+      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+      val classifier = getClassifier
+      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
+    }.toArray[ClassificationModel[_,_]]
+
+    if (handlePersistence) {
+      multiclassLabeled.unpersist()
+    }
+
+    // extract label metadata from label column if present, or create a nominal attribute
+    // to output the number of labels
+    val labelAttribute = Attribute.fromStructField(labelSchema) match {
+      case _: NumericAttribute | UnresolvedAttribute => {
+        NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+      }
+      case attr: Attribute => attr
+    }
+    copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index c84c8b4..56075c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.util
 import scala.collection.immutable.HashMap
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
-  NumericAttribute}
+import org.apache.spark.ml.attribute._
 import org.apache.spark.sql.types.StructField
 
 
@@ -39,9 +38,9 @@ object MetadataUtils {
    */
   def getNumClasses(labelSchema: StructField): Option[Int] = {
     Attribute.fromStructField(labelSchema) match {
-      case numAttr: NumericAttribute => None
       case binAttr: BinaryAttribute => Some(2)
       case nomAttr: NominalAttribute => nomAttr.getNumValues
+      case _: NumericAttribute | UnresolvedAttribute => None
     }
   }
 
@@ -65,7 +64,7 @@ object MetadataUtils {
           Iterator()
         } else {
           attr match {
-            case numAttr: NumericAttribute => Iterator()
+            case _: NumericAttribute | UnresolvedAttribute => Iterator()
             case binAttr: BinaryAttribute => Iterator(idx -> 2)
             case nomAttr: NominalAttribute =>
               nomAttr.getNumValues match {

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
new file mode 100644
index 0000000..40a90ae
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import static scala.collection.JavaConversions.seqAsJavaList;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaOneVsRestSuite implements Serializable {
+
+    private transient JavaSparkContext jsc;
+    private transient SQLContext jsql;
+    private transient DataFrame dataset;
+    private transient JavaRDD<LabeledPoint> datasetRDD;
+
+    @Before
+    public void setUp() {
+        jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
+        jsql = new SQLContext(jsc);
+        int nPoints = 3;
+
+        /**
+         * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
+         * As a result, we are actually drawing samples from probability distribution of built model.
+         */
+        double[] weights = {
+                -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+                -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+
+        double[] xMean = {5.843, 3.057, 3.758, 1.199};
+        double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+        List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
+                weights, xMean, xVariance, true, nPoints, 42));
+        datasetRDD = jsc.parallelize(points, 2);
+        dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+    }
+
+    @After
+    public void tearDown() {
+        jsc.stop();
+        jsc = null;
+    }
+
+    @Test
+    public void oneVsRestDefaultParams() {
+        OneVsRest ova = new OneVsRest();
+        ova.setClassifier(new LogisticRegression());
+        Assert.assertEquals(ova.getLabelCol() , "label");
+        Assert.assertEquals(ova.getPredictionCol() , "prediction");
+        OneVsRestModel ovaModel = ova.fit(dataset);
+        DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+        predictions.collectAsList();
+        Assert.assertEquals(ovaModel.getLabelCol(), "label");
+        Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
+    }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 3e1a719..ec9b717 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.attribute
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, Metadata}
+import org.apache.spark.sql.types._
 
 class AttributeSuite extends FunSuite {
 
@@ -209,4 +209,12 @@ class AttributeSuite extends FunSuite {
     intercept[IllegalArgumentException](attr.withName(""))
     intercept[IllegalArgumentException](attr.withIndex(-1))
   }
+
+  test("attribute from struct field") {
+    val metadata = NumericAttribute.defaultAttr.withName("label").toMetadata()
+    val fldWithoutMeta = new StructField("x", DoubleType, false, Metadata.empty)
+    assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
+    val fldWithMeta = new StructField("x", DoubleType, false, metadata)
+    assert(Attribute.fromStructField(fldWithMeta).isNumeric)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/595a6758/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
new file mode 100644
index 0000000..ebec7c6
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
+
+  @transient var sqlContext: SQLContext = _
+  @transient var dataset: DataFrame = _
+  @transient var rdd: RDD[LabeledPoint] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+    val nPoints = 1000
+
+    /**
+     * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
+     * As a result, we are actually drawing samples from probability distribution of built model.
+     */
+    val weights = Array(
+      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+    val xMean = Array(5.843, 3.057, 3.758, 1.199)
+    val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+    rdd = sc.parallelize(generateMultinomialLogisticInput(
+      weights, xMean, xVariance, true, nPoints, 42), 2)
+    dataset = sqlContext.createDataFrame(rdd)
+  }
+
+  test("one-vs-rest: default params") {
+    val numClasses = 3
+    val ova = new OneVsRest()
+    ova.setClassifier(new LogisticRegression)
+    assert(ova.getLabelCol === "label")
+    assert(ova.getPredictionCol === "prediction")
+    val ovaModel = ova.fit(dataset)
+    assert(ovaModel.models.size === numClasses)
+    val transformedDataset = ovaModel.transform(dataset)
+
+    // check for label metadata in prediction col
+    val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
+    assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))
+
+    val ovaResults = transformedDataset
+      .select("prediction", "label")
+      .map(row => (row.getDouble(0), row.getDouble(1)))
+
+    val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
+    lr.optimizer.setRegParam(0.1).setNumIterations(100)
+
+    val model = lr.run(rdd)
+    val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label))
+    // compute the confusion matrix of each model's predictions.
+    // bound how much error we allow compared to multinomial logistic regression.
+    val expectedMetrics = new MulticlassMetrics(results)
+    val ovaMetrics = new MulticlassMetrics(ovaResults)
+    assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
+  }
+
+  test("one-vs-rest: pass label metadata correctly during train") {
+    val numClasses = 3
+    val ova = new OneVsRest()
+    ova.setClassifier(new MockLogisticRegression)
+
+    val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+    val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata())
+    val features = dataset("features").as("features")
+    val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
+    ova.fit(datasetWithLabelMetadata)
+  }
+}
+
+private class MockLogisticRegression extends LogisticRegression {
+
+  setMaxIter(1)
+
+  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
+    val labelSchema = dataset.schema($(labelCol))
+    // check for label attribute propagation.
+    assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
+    super.train(dataset)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org