You are viewing a plain-text version of this content. The canonical link to the original message (a hyperlink labelled "here" on the source page) was not preserved in this copy.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/13 01:42:32 UTC
spark git commit: [SPARK-7573] [ML] OneVsRest cleanups
Repository: spark
Updated Branches:
refs/heads/master f0c1bc347 -> 96c4846db
[SPARK-7573] [ML] OneVsRest cleanups
Minor cleanups discussed with [~mengxr]:
* move OneVsRest from reduction to classification sub-package
* make model constructor private
Some doc cleanups too
CC: harsha2010 Could you please verify this looks OK? Thanks!
Author: Joseph K. Bradley <jo...@databricks.com>
Closes #6097 from jkbradley/onevsrest-cleanup and squashes the following commits:
4ecd48d [Joseph K. Bradley] org imports
430b065 [Joseph K. Bradley] moved OneVsRest from reduction subpackage to classification. small java doc style fixes
9f8b9b9 [Joseph K. Bradley] Small cleanups to OneVsRest. Made model constructor private to ml package.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/96c4846d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/96c4846d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/96c4846d
Branch: refs/heads/master
Commit: 96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8
Parents: f0c1bc3
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Tue May 12 16:42:30 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 12 16:42:30 2015 -0700
----------------------------------------------------------------------
.../spark/ml/classification/OneVsRest.scala | 209 ++++++++++++++++++
.../apache/spark/ml/reduction/OneVsRest.scala | 211 -------------------
.../ml/classification/JavaOneVsRestSuite.java | 82 +++++++
.../spark/ml/reduction/JavaOneVsRestSuite.java | 85 --------
.../ml/classification/OneVsRestSuite.scala | 110 ++++++++++
.../spark/ml/reduction/OneVsRestSuite.scala | 113 ----------
6 files changed, 401 insertions(+), 409 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/96c4846d/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
new file mode 100644
index 0000000..afb8d75
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import java.util.UUID
+
+import scala.language.existentials
+
+import org.apache.spark.annotation.{AlphaComponent, Experimental}
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for [[OneVsRest]]: the base binary classifier that is trained once per class.
+ */
+private[ml] trait OneVsRestParams extends PredictorParams {
+ // Any Classifier whose estimator type E and model type M agree on the same feature type F.
+ type ClassifierType = Classifier[F, E, M] forSome {
+ type F
+ type M <: ClassificationModel[F, M]
+ type E <: Classifier[F, E, M]
+ }
+
+ /**
+ * Param for the base binary classifier that we reduce multiclass classification into.
+ * @group param
+ */
+ val classifier: Param[ClassifierType] =
+ new Param(this, "classifier", "base binary classifier")
+
+ /** @group getParam */
+ def getClassifier: ClassifierType = $(classifier)
+
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by [[OneVsRest]].
+ * This stores the models resulting from training k binary classifiers: one for each class.
+ * Each example is scored against all k models, and the model with the highest score
+ * is picked to label the example (the prediction is that model's index in `models`).
+ *
+ * @param labelMetadata Label metadata attached to the output prediction column: the training label
+ * column's attribute, or a Nominal attribute holding the number of classes otherwise.
+ * @param models The binary classification models for the reduction.
+ * The i-th model is produced by training the i-th class (taking label 1) vs the rest
+ * (taking label 0).
+ */
+@AlphaComponent
+class OneVsRestModel private[ml] (
+ override val parent: OneVsRest,
+ labelMetadata: Metadata,
+ val models: Array[_ <: ClassificationModel[_,_]])
+ extends Model[OneVsRestModel] with OneVsRestParams {
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ // Validate the input schema (the returned schema itself is not used).
+ transformSchema(dataset.schema, logging = true)
+
+ // Input columns: every intermediate select keeps these so they pass through to the output.
+ val origCols = dataset.schema.map(f => col(f.name))
+
+ // Accumulator column (random name to avoid clashes): Map[model index -> score], initially empty.
+ val accColName = "mbc$acc" + UUID.randomUUID().toString
+ val init: () => Map[Int, Double] = () => {Map()}
+ val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
+ val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
+
+ // Persist if input isn't persisted. NOTE(review): dataset.rdd may always report NONE - confirm.
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) {
+ newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ // Fold over the models: model `index` adds its rawPrediction(1) score under key `index`.
+ val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
+ case (df, (model, index)) =>
+ val rawPredictionCol = model.getRawPredictionCol
+ val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
+
+ // Write the updated map to a temporary column, since the accumulator is also an input here.
+ val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
+ val update: (Map[Int, Double], Vector) => Map[Int, Double] =
+ (predictions: Map[Int, Double], prediction: Vector) => {
+ predictions + ((index, prediction(1)))
+ }
+ val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
+ val transformedDataset = model.transform(df).select(columns:_*)
+ val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
+ val newColumns = origCols ++ List(col(tmpColName))
+
+ // Keep only the input columns plus the new map, renamed back to the accumulator name.
+ updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
+ }
+ // NOTE(review): aggregatedDataset is lazy, so this unpersist may run before the cache is used.
+ if (handlePersistence) {
+ newDataset.unpersist()
+ }
+
+ // Predict the index of the highest-scoring model (maxBy throws if `models` is empty).
+ val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
+ predictions.maxBy(_._2)._1.toDouble
+ }
+
+ // Emit the prediction column, tagged with labelMetadata.
+ val labelUdf = callUDF(label, DoubleType, col(accColName))
+ aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Reduction of Multiclass Classification to Binary Classification.
+ * Performs reduction using the one-against-all strategy.
+ * For a multiclass classification with k classes, train k models (one per class).
+ * Each example is scored against all k models and the model with the highest score
+ * is picked to label the example.
+ */
+@Experimental
+final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
+
+ /** @group setParam */
+ def setClassifier(value: Classifier[_,_,_]): this.type = {
+ // TODO: Find a better way to do this. Existential types don't work with the Java API, so cast.
+ set(classifier, value.asInstanceOf[ClassifierType])
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
+ }
+
+ override def fit(dataset: DataFrame): OneVsRestModel = {
+ // Number of classes: from the label column's metadata if present, otherwise max label + 1.
+ val labelSchema = dataset.schema($(labelCol))
+ val computeNumClasses: () => Int = () => {
+ val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+ // classes are assumed to be numbered 0, ..., maxLabelIndex (and the label column to hold Doubles)
+ maxLabelIndex.toInt + 1
+ }
+ val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
+
+ val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+
+ // Persist if input isn't persisted. NOTE(review): dataset.rdd may always report NONE - confirm.
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) {
+ multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ // Train k binary models in parallel (.par), one per class index.
+ val models = Range(0, numClasses).par.map { index =>
+
+ val label: Double => Double = (label: Double) => {
+ if (label.toInt == index) 1.0 else 0.0
+ }
+
+ // Relabel into column "mc2b$<index>" (1.0 for this class, else 0.0) with binary label metadata.
+ // TODO: use when ... otherwise after SPARK-7321 is merged
+ val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
+ val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
+ val labelColName = "mc2b$" + index
+ val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
+ val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+ val classifier = getClassifier
+ classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
+ }.toArray[ClassificationModel[_,_]]
+
+ if (handlePersistence) {
+ multiclassLabeled.unpersist()
+ }
+
+ // Reuse the label column's attribute unless it is numeric or unresolved; in that case create a
+ // nominal attribute recording the number of classes.
+ val labelAttribute = Attribute.fromStructField(labelSchema) match {
+ case _: NumericAttribute | UnresolvedAttribute =>
+ NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+ case attr: Attribute => attr
+ }
+ copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/96c4846d/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
deleted file mode 100644
index 0a6728e..0000000
--- a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
+++ /dev/null
@@ -1,211 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.reduction
-
-import java.util.UUID
-
-import scala.language.existentials
-
-import org.apache.spark.annotation.{AlphaComponent, Experimental}
-import org.apache.spark.ml._
-import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
-import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.MetadataUtils
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
-import org.apache.spark.storage.StorageLevel
-
-/**
- * Params for [[OneVsRest]].
- */
-private[ml] trait OneVsRestParams extends PredictorParams {
-
- type ClassifierType = Classifier[F, E, M] forSome {
- type F
- type M <: ClassificationModel[F, M]
- type E <: Classifier[F, E, M]
- }
-
- /**
- * param for the base binary classifier that we reduce multiclass classification into.
- * @group param
- */
- val classifier: Param[ClassifierType] =
- new Param(this, "classifier", "base binary classifier ")
-
- /** @group getParam */
- def getClassifier: ClassifierType = $(classifier)
-
-}
-
-/**
- * Model produced by [[OneVsRest]].
- * Stores the models resulting from training k different classifiers:
- * one for each class.
- * Each example is scored against all k models and the model with highest score
- * is picked to label the example.
- * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
- * @param parent
- * @param labelMetadata Metadata of label column if it exists, or Nominal attribute
- * representing the number of classes in training dataset otherwise.
- * @param models the binary classification models for reduction.
- * The i-th model is produced by testing the i-th class vs the rest.
- */
-@AlphaComponent
-class OneVsRestModel(
- override val parent: OneVsRest,
- labelMetadata: Metadata,
- val models: Array[_ <: ClassificationModel[_,_]])
- extends Model[OneVsRestModel] with OneVsRestParams {
-
- override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
- }
-
- override def transform(dataset: DataFrame): DataFrame = {
- // Check schema
- transformSchema(dataset.schema, logging = true)
-
- // determine the input columns: these need to be passed through
- val origCols = dataset.schema.map(f => col(f.name))
-
- // add an accumulator column to store predictions of all the models
- val accColName = "mbc$acc" + UUID.randomUUID().toString
- val init: () => Map[Int, Double] = () => {Map()}
- val mapType = MapType(IntegerType, DoubleType, false)
- val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
-
- // persist if underlying dataset is not persistent.
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
- if (handlePersistence) {
- newDataset.persist(StorageLevel.MEMORY_AND_DISK)
- }
-
- // update the accumulator column with the result of prediction of models
- val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
- case (df, (model, index)) => {
- val rawPredictionCol = model.getRawPredictionCol
- val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
-
- // add temporary column to store intermediate scores and update
- val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
- val update: (Map[Int, Double], Vector) => Map[Int, Double] =
- (predictions: Map[Int, Double], prediction: Vector) => {
- predictions + ((index, prediction(1)))
- }
- val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
- val transformedDataset = model.transform(df).select(columns:_*)
- val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
- val newColumns = origCols ++ List(col(tmpColName))
-
- // switch out the intermediate column with the accumulator column
- updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
- }
- }
-
- if (handlePersistence) {
- newDataset.unpersist()
- }
-
- // output the index of the classifier with highest confidence as prediction
- val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
- predictions.maxBy(_._2)._1.toDouble
- }
-
- // output label and label metadata as prediction
- val labelUdf = callUDF(label, DoubleType, col(accColName))
- aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
- }
-}
-
-/**
- * :: Experimental ::
- *
- * Reduction of Multiclass Classification to Binary Classification.
- * Performs reduction using one against all strategy.
- * For a multiclass classification with k classes, train k models (one per class).
- * Each example is scored against all k models and the model with highest score
- * is picked to label the example.
- */
-@Experimental
-final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
-
- /** @group setParam */
- // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
- def setClassifier(value: Classifier[_,_,_]): this.type = {
- set(classifier, value.asInstanceOf[ClassifierType])
- }
-
- override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
- }
-
- override def fit(dataset: DataFrame): OneVsRestModel = {
- // determine number of classes either from metadata if provided, or via computation.
- val labelSchema = dataset.schema($(labelCol))
- val computeNumClasses: () => Int = () => {
- val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
- // classes are assumed to be numbered from 0,...,maxLabelIndex
- maxLabelIndex.toInt + 1
- }
- val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
-
- val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
-
- // persist if underlying dataset is not persistent.
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
- if (handlePersistence) {
- multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
- }
-
- // create k columns, one for each binary classifier.
- val models = Range(0, numClasses).par.map { index =>
-
- val label: Double => Double = (label: Double) => {
- if (label.toInt == index) 1.0 else 0.0
- }
-
- // generate new label metadata for the binary problem.
- // TODO: use when ... otherwise after SPARK-7321 is merged
- val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
- val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
- val labelColName = "mc2b$" + index
- val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
- val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
- val classifier = getClassifier
- classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
- }.toArray[ClassificationModel[_,_]]
-
- if (handlePersistence) {
- multiclassLabeled.unpersist()
- }
-
- // extract label metadata from label column if present, or create a nominal attribute
- // to output the number of labels
- val labelAttribute = Attribute.fromStructField(labelSchema) match {
- case _: NumericAttribute | UnresolvedAttribute => {
- NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
- }
- case attr: Attribute => attr
- }
- copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/96c4846d/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
new file mode 100644
index 0000000..a1ee554
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.List;
+
+import static scala.collection.JavaConversions.seqAsJavaList;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaOneVsRestSuite implements Serializable {
+ // Java-API smoke test: OneVsRest must be usable from Java with its default column names.
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaOneVsRestSuite");
+ jsql = new SQLContext(jsc);
+ int nPoints = 3;
+
+ // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
+ // As a result, we are drawing samples from probability distribution of an actual model.
+ double[] weights = {
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+
+ double[] xMean = {5.843, 3.057, 3.758, 1.199};
+ double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+ List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42));
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void oneVsRestDefaultParams() {
+ OneVsRest ova = new OneVsRest();
+ ova.setClassifier(new LogisticRegression());
+ Assert.assertEquals("label", ova.getLabelCol());
+ Assert.assertEquals("prediction", ova.getPredictionCol());
+ OneVsRestModel ovaModel = ova.fit(dataset);
+ DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+ Assert.assertEquals(dataset.count(), predictions.collectAsList().size());
+ Assert.assertEquals("label", ovaModel.getLabelCol());
+ Assert.assertEquals("prediction", ovaModel.getPredictionCol());
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/96c4846d/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
deleted file mode 100644
index 40a90ae..0000000
--- a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.reduction;
-
-import java.io.Serializable;
-import java.util.List;
-
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-import static scala.collection.JavaConversions.seqAsJavaList;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegression;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
-
-public class JavaOneVsRestSuite implements Serializable {
-
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
- private transient DataFrame dataset;
- private transient JavaRDD<LabeledPoint> datasetRDD;
-
- @Before
- public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
- jsql = new SQLContext(jsc);
- int nPoints = 3;
-
- /**
- * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
- * As a result, we are actually drawing samples from probability distribution of built model.
- */
- double[] weights = {
- -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
- -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
-
- double[] xMean = {5.843, 3.057, 3.758, 1.199};
- double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
- List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
- weights, xMean, xVariance, true, nPoints, 42));
- datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
- }
-
- @After
- public void tearDown() {
- jsc.stop();
- jsc = null;
- }
-
- @Test
- public void oneVsRestDefaultParams() {
- OneVsRest ova = new OneVsRest();
- ova.setClassifier(new LogisticRegression());
- Assert.assertEquals(ova.getLabelCol() , "label");
- Assert.assertEquals(ova.getPredictionCol() , "prediction");
- OneVsRestModel ovaModel = ova.fit(dataset);
- DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
- predictions.collectAsList();
- Assert.assertEquals(ovaModel.getLabelCol(), "label");
- Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/96c4846d/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
new file mode 100644
index 0000000..e65ffae
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+ @transient var dataset: DataFrame = _
+ @transient var rdd: RDD[LabeledPoint] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ val nPoints = 1000
+
+ // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
+ // As a result, we are drawing samples from probability distribution of an actual model.
+ val weights = Array(
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+ rdd = sc.parallelize(generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42), 2)
+ dataset = sqlContext.createDataFrame(rdd)
+ }
+
+ test("one-vs-rest: default params") {
+ val numClasses = 3
+ val ova = new OneVsRest()
+ ova.setClassifier(new LogisticRegression)
+ assert(ova.getLabelCol === "label")
+ assert(ova.getPredictionCol === "prediction")
+ val ovaModel = ova.fit(dataset)
+ assert(ovaModel.models.size === numClasses)
+ val transformedDataset = ovaModel.transform(dataset)
+
+ // check for label metadata in prediction col
+ val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
+ assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(numClasses))
+
+ val ovaResults = transformedDataset
+ .select("prediction", "label")
+ .map(row => (row.getDouble(0), row.getDouble(1)))
+
+ val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
+ lr.optimizer.setRegParam(0.1).setNumIterations(100)
+
+ val model = lr.run(rdd)
+ val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label))
+ // compare the confusion matrix of one-vs-rest against multinomial logistic regression,
+ // bounding how much error we allow relative to the multinomial model.
+ val expectedMetrics = new MulticlassMetrics(results)
+ val ovaMetrics = new MulticlassMetrics(ovaResults)
+ assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
+ }
+
+ test("one-vs-rest: pass label metadata correctly during train") {
+ val numClasses = 3
+ val ova = new OneVsRest()
+ ova.setClassifier(new MockLogisticRegression)
+ // MockLogisticRegression asserts inside train() that each binary label column is 2-class.
+ val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+ val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata())
+ val features = dataset("features").as("features")
+ val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
+ assert(ova.fit(datasetWithLabelMetadata).models.size === numClasses)
+ }
+}
+
+private class MockLogisticRegression extends LogisticRegression {
+ // LogisticRegression stub (one iteration) that checks the label metadata each binary fit sees.
+ setMaxIter(1)
+
+ override protected def train(dataset: DataFrame): LogisticRegressionModel = {
+ val labelSchema = dataset.schema($(labelCol))
+ // every binary sub-problem must see 2-class label metadata; `forall` also accepts no metadata.
+ assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
+ super.train(dataset)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/96c4846d/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
deleted file mode 100644
index ebec7c6..0000000
--- a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.reduction
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.ml.attribute.NominalAttribute
-import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
-import org.apache.spark.ml.util.MetadataUtils
-import org.apache.spark.mllib.classification.LogisticRegressionSuite._
-import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
-
-class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
-
- @transient var sqlContext: SQLContext = _
- @transient var dataset: DataFrame = _
- @transient var rdd: RDD[LabeledPoint] = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- sqlContext = new SQLContext(sc)
- val nPoints = 1000
-
- /**
- * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
- * As a result, we are actually drawing samples from probability distribution of built model.
- */
- val weights = Array(
- -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
- -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
-
- val xMean = Array(5.843, 3.057, 3.758, 1.199)
- val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
- rdd = sc.parallelize(generateMultinomialLogisticInput(
- weights, xMean, xVariance, true, nPoints, 42), 2)
- dataset = sqlContext.createDataFrame(rdd)
- }
-
- test("one-vs-rest: default params") {
- val numClasses = 3
- val ova = new OneVsRest()
- ova.setClassifier(new LogisticRegression)
- assert(ova.getLabelCol === "label")
- assert(ova.getPredictionCol === "prediction")
- val ovaModel = ova.fit(dataset)
- assert(ovaModel.models.size === numClasses)
- val transformedDataset = ovaModel.transform(dataset)
-
- // check for label metadata in prediction col
- val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
- assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))
-
- val ovaResults = transformedDataset
- .select("prediction", "label")
- .map(row => (row.getDouble(0), row.getDouble(1)))
-
- val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
- lr.optimizer.setRegParam(0.1).setNumIterations(100)
-
- val model = lr.run(rdd)
- val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label))
- // determine the #confusion matrix in each class.
- // bound how much error we allow compared to multinomial logistic regression.
- val expectedMetrics = new MulticlassMetrics(results)
- val ovaMetrics = new MulticlassMetrics(ovaResults)
- assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
- }
-
- test("one-vs-rest: pass label metadata correctly during train") {
- val numClasses = 3
- val ova = new OneVsRest()
- ova.setClassifier(new MockLogisticRegression)
-
- val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
- val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata())
- val features = dataset("features").as("features")
- val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
- ova.fit(datasetWithLabelMetadata)
- }
-}
-
-private class MockLogisticRegression extends LogisticRegression {
-
- setMaxIter(1)
-
- override protected def train(dataset: DataFrame): LogisticRegressionModel = {
- val labelSchema = dataset.schema($(labelCol))
- // check for label attribute propagation.
- assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
- super.train(dataset)
- }
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org