You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/11/01 17:46:57 UTC

spark git commit: [SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit

Repository: spark
Updated Branches:
  refs/heads/master 0cba535af -> 8ac09108f


[SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit

## What changes were proposed in this pull request?

1. Move the cast of the label column to `DoubleType` into `Predictor.fit`.
2. Remove the now-unnecessary label casts from the individual estimators.
## How was this patch tested?

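Existing tests, plus a new `PredictorSuite` that checks `fit` accepts every `NumericType` label and rejects a `StringType` label.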

Author: Zheng RuiFeng <ru...@foxmail.com>

Closes #15414 from zhengruifeng/move_cast.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8ac09108
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8ac09108
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8ac09108

Branch: refs/heads/master
Commit: 8ac09108fcf3fb62a812333a5b386b566a9d98ec
Parents: 0cba535
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Tue Nov 1 10:46:36 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Nov 1 10:46:36 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Predictor.scala   | 12 ++-
 .../spark/ml/classification/Classifier.scala    |  4 +-
 .../spark/ml/classification/GBTClassifier.scala |  2 +-
 .../ml/classification/LogisticRegression.scala  |  2 +-
 .../spark/ml/classification/NaiveBayes.scala    |  2 +-
 .../GeneralizedLinearRegression.scala           |  2 +-
 .../spark/ml/regression/LinearRegression.scala  |  2 +-
 .../org/apache/spark/ml/PredictorSuite.scala    | 82 ++++++++++++++++++++
 .../LogisticRegressionSuite.scala               |  1 -
 9 files changed, 98 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e29d7f4..aa92edd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params
 
 /**
  * :: DeveloperApi ::
- * Abstraction for prediction problems (regression and classification).
+ * Abstraction for prediction problems (regression and classification). It accepts all NumericType
+ * labels and will automatically cast them to DoubleType in [[fit()]].
  *
  * @tparam FeaturesType  Type of features.
  *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
@@ -87,7 +88,12 @@ abstract class Predictor[
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
-    copyValues(train(dataset).setParent(this))
+
+    // Cast LabelCol to DoubleType and keep the metadata.
+    val labelMeta = dataset.schema($(labelCol)).metadata
+    val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+
+    copyValues(train(casted).setParent(this))
   }
 
   override def copy(extra: ParamMap): Learner
@@ -121,7 +127,7 @@ abstract class Predictor[
    * and put it in an RDD with strong types.
    */
   protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
-    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) => LabeledPoint(label, features)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index d1b21b1..a3da306 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -71,7 +71,7 @@ abstract class Classifier[
    * and put it in an RDD with strong types.
    *
    * @param dataset  DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
-   *                 and features ([[Vector]]). Labels are cast to [[DoubleType]].
+   *                 and features ([[Vector]]).
    * @param numClasses  Number of classes label can take.  Labels must be integers in the range
    *                    [0, numClasses).
    * @throws SparkException  if any label is not an integer >= 0
@@ -79,7 +79,7 @@ abstract class Classifier[
   protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
     require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
       s" $numClasses, but requires numClasses > 0.")
-    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) =>
         require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
           s" dataset with invalid label $label.  Labels must be integers in range" +

http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 8bffe0c..f8f164e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") (
     // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
     // 2 classes now.  This lets us provide a more precise error message.
     val oldDataset: RDD[LabeledPoint] =
-      dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
         case Row(label: Double, features: Vector) =>
           require(label == 0 || label == 1, s"GBTClassifier was given" +
             s" dataset with invalid label $label.  Labels must be in {0,1}; note that" +

http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fdaae0..c465105 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") (
       LogisticRegressionModel = {
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 994ed99..b03a07a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") (
     // Aggregates term frequencies per label.
     // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd
+    val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
       .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
       }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
       seqOp = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 33cb25c..8656ecf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
 
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 519f3bd..ae876b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
     val instances: RDD[Instance] = dataset.select(
-      col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      col($(labelCol)), w, col($(featuresCol))).rdd.map {
       case Row(label: Double, weight: Double, features: Vector) =>
         Instance(label, weight, features)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
new file mode 100644
index 0000000..03e0c53
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import PredictorSuite._
+
+  test("should support all NumericType labels and not support other types") {
+    val df = spark.createDataFrame(Seq(
+      (0, Vectors.dense(0, 2, 3)),
+      (1, Vectors.dense(0, 3, 9)),
+      (0, Vectors.dense(0, 2, 6))
+    )).toDF("label", "features")
+
+    val types =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+
+    val predictor = new MockPredictor()
+
+    types.foreach { t =>
+      predictor.fit(df.select(col("label").cast(t), col("features")))
+    }
+
+    intercept[IllegalArgumentException] {
+      predictor.fit(df.select(col("label").cast(StringType), col("features")))
+    }
+  }
+}
+
+object PredictorSuite {
+
+  class MockPredictor(override val uid: String)
+    extends Predictor[Vector, MockPredictor, MockPredictionModel] {
+
+    def this() = this(Identifiable.randomUID("mockpredictor"))
+
+    override def train(dataset: Dataset[_]): MockPredictionModel = {
+      require(dataset.schema("label").dataType == DoubleType)
+      new MockPredictionModel(uid)
+    }
+
+    override def copy(extra: ParamMap): MockPredictor =
+      throw new NotImplementedError()
+  }
+
+  class MockPredictionModel(override val uid: String)
+    extends PredictionModel[Vector, MockPredictionModel] {
+
+    def this() = this(Identifiable.randomUID("mockpredictormodel"))
+
+    override def predict(features: Vector): Double =
+      throw new NotImplementedError()
+
+    override def copy(extra: ParamMap): MockPredictionModel =
+      throw new NotImplementedError()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index bc631dc..8771fd2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -1807,7 +1807,6 @@ class LogisticRegressionSuite
         .objectiveHistory
         .sliding(2)
         .forall(x => x(0) >= x(1)))
-
   }
 
   test("binary logistic regression with weighted data") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org