You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/11/01 17:46:57 UTC
spark git commit: [SPARK-17848][ML] Move LabelCol datatype cast into
Predictor.fit
Repository: spark
Updated Branches:
refs/heads/master 0cba535af -> 8ac09108f
[SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit
## What changes were proposed in this pull request?
1. Move the label-column cast into `Predictor`
2. Then remove the now-unnecessary casts from the individual learners
## How was this patch tested?
existing tests
Author: Zheng RuiFeng <ru...@foxmail.com>
Closes #15414 from zhengruifeng/move_cast.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8ac09108
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8ac09108
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8ac09108
Branch: refs/heads/master
Commit: 8ac09108fcf3fb62a812333a5b386b566a9d98ec
Parents: 0cba535
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Tue Nov 1 10:46:36 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Nov 1 10:46:36 2016 -0700
----------------------------------------------------------------------
.../scala/org/apache/spark/ml/Predictor.scala | 12 ++-
.../spark/ml/classification/Classifier.scala | 4 +-
.../spark/ml/classification/GBTClassifier.scala | 2 +-
.../ml/classification/LogisticRegression.scala | 2 +-
.../spark/ml/classification/NaiveBayes.scala | 2 +-
.../GeneralizedLinearRegression.scala | 2 +-
.../spark/ml/regression/LinearRegression.scala | 2 +-
.../org/apache/spark/ml/PredictorSuite.scala | 82 ++++++++++++++++++++
.../LogisticRegressionSuite.scala | 1 -
9 files changed, 98 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e29d7f4..aa92edd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params
/**
* :: DeveloperApi ::
- * Abstraction for prediction problems (regression and classification).
+ * Abstraction for prediction problems (regression and classification). It accepts all NumericType
+ * labels and will automatically cast it to DoubleType in [[fit()]].
*
* @tparam FeaturesType Type of features.
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
@@ -87,7 +88,12 @@ abstract class Predictor[
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
- copyValues(train(dataset).setParent(this))
+
+ // Cast LabelCol to DoubleType and keep the metadata.
+ val labelMeta = dataset.schema($(labelCol)).metadata
+ val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+
+ copyValues(train(casted).setParent(this))
}
override def copy(extra: ParamMap): Learner
@@ -121,7 +127,7 @@ abstract class Predictor[
* and put it in an RDD with strong types.
*/
protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index d1b21b1..a3da306 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -71,7 +71,7 @@ abstract class Classifier[
* and put it in an RDD with strong types.
*
* @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
- * and features ([[Vector]]). Labels are cast to [[DoubleType]].
+ * and features ([[Vector]]).
* @param numClasses Number of classes label can take. Labels must be integers in the range
* [0, numClasses).
* @throws SparkException if any label is not an integer >= 0
@@ -79,7 +79,7 @@ abstract class Classifier[
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 8bffe0c..f8f164e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") (
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
// 2 classes now. This lets us provide a more precise error message.
val oldDataset: RDD[LabeledPoint] =
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label == 0 || label == 1, s"GBTClassifier was given" +
s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fdaae0..c465105 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") (
LogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 994ed99..b03a07a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") (
// Aggregates term frequencies per label.
// TODO: Calling aggregateByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
- val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd
+ val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
.map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
}.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
seqOp = {
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 33cb25c..8656ecf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 519f3bd..ae876b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] = dataset.select(
- col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
new file mode 100644
index 0000000..03e0c53
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ import PredictorSuite._
+
+ test("should support all NumericType labels and not support other types") {
+ val df = spark.createDataFrame(Seq(
+ (0, Vectors.dense(0, 2, 3)),
+ (1, Vectors.dense(0, 3, 9)),
+ (0, Vectors.dense(0, 2, 6))
+ )).toDF("label", "features")
+
+ val types =
+ Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+
+ val predictor = new MockPredictor()
+
+ types.foreach { t =>
+ predictor.fit(df.select(col("label").cast(t), col("features")))
+ }
+
+ intercept[IllegalArgumentException] {
+ predictor.fit(df.select(col("label").cast(StringType), col("features")))
+ }
+ }
+}
+
+object PredictorSuite {
+
+ class MockPredictor(override val uid: String)
+ extends Predictor[Vector, MockPredictor, MockPredictionModel] {
+
+ def this() = this(Identifiable.randomUID("mockpredictor"))
+
+ override def train(dataset: Dataset[_]): MockPredictionModel = {
+ require(dataset.schema("label").dataType == DoubleType)
+ new MockPredictionModel(uid)
+ }
+
+ override def copy(extra: ParamMap): MockPredictor =
+ throw new NotImplementedError()
+ }
+
+ class MockPredictionModel(override val uid: String)
+ extends PredictionModel[Vector, MockPredictionModel] {
+
+ def this() = this(Identifiable.randomUID("mockpredictormodel"))
+
+ override def predict(features: Vector): Double =
+ throw new NotImplementedError()
+
+ override def copy(extra: ParamMap): MockPredictionModel =
+ throw new NotImplementedError()
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/8ac09108/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index bc631dc..8771fd2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -1807,7 +1807,6 @@ class LogisticRegressionSuite
.objectiveHistory
.sliding(2)
.forall(x => x(0) >= x(1)))
-
}
test("binary logistic regression with weighted data") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org