You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2017/08/08 03:18:59 UTC
spark git commit: [SPARK-21306][ML] For branch 2.0,
OneVsRest should support setWeightCol
Repository: spark
Updated Branches:
refs/heads/branch-2.0 c27a01aec -> 9f670ce5d
[SPARK-21306][ML] For branch 2.0, OneVsRest should support setWeightCol
This PR is related to #18554 and has been adapted for branch 2.0.
## What changes were proposed in this pull request?
Add a `setWeightCol` method to OneVsRest.
`weightCol` is ignored if the classifier doesn't inherit the `HasWeightCol` trait.
## How was this patch tested?
+ [x] add a unit test.
Author: Yan Facai (颜发才) <fa...@gmail.com>
Closes #18764 from facaiy/BUG/branch-2.0_OneVsRest_support_setWeightCol.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9f670ce5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9f670ce5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9f670ce5
Branch: refs/heads/branch-2.0
Commit: 9f670ce5d1aeef737226185d78f07147f0cc2693
Parents: c27a01a
Author: Yan Facai (颜发才) <fa...@gmail.com>
Authored: Tue Aug 8 11:18:15 2017 +0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Tue Aug 8 11:18:15 2017 +0800
----------------------------------------------------------------------
.../spark/ml/classification/OneVsRest.scala | 39 ++++++++++++++++++--
.../ml/classification/OneVsRestSuite.scala | 11 ++++++
python/pyspark/ml/classification.py | 27 +++++++++++---
python/pyspark/ml/tests.py | 14 +++++++
4 files changed, 82 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/9f670ce5/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index f4ab0a0..770d5db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -34,6 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
-private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
+private[ml] trait OneVsRestParams extends PredictorParams
+ with ClassifierTypeTrait with HasWeightCol {
/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -290,6 +292,18 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /**
+ * Sets the value of param [[weightCol]].
+ *
+ * This is ignored if weight is not supported by [[classifier]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("2.3.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
@@ -308,7 +322,20 @@ final class OneVsRest @Since("1.4.0") (
}
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
- val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+ val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
+ getClassifier match {
+ case _: HasWeightCol => true
+ case c =>
+ logWarning(s"weightCol is ignored, as it is not supported by $c now.")
+ false
+ }
+ }
+
+ val multiclassLabeled = if (weightColIsUsed) {
+ dataset.select($(labelCol), $(featuresCol), $(weightCol))
+ } else {
+ dataset.select($(labelCol), $(featuresCol))
+ }
// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -328,7 +355,13 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
- classifier.fit(trainingDataset, paramMap)
+ if (weightColIsUsed) {
+ val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
+ paramMap.put(classifier_.weightCol -> getWeightCol)
+ classifier_.fit(trainingDataset, paramMap)
+ } else {
+ classifier.fit(trainingDataset, paramMap)
+ }
}.toArray[ClassificationModel[_, _]]
if (handlePersistence) {
http://git-wip-us.apache.org/repos/asf/spark/blob/9f670ce5/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 361dd74..a266704 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.Metadata
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -143,6 +144,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}
+ test("SPARK-21306: OneVsRest should support setWeightCol") {
+ val dataset2 = dataset.withColumn("weight", lit(1.0))
+ // classifier inherits hasWeightCol
+ val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression())
+ assert(ova.fit(dataset2) !== null)
+ // classifier doesn't inherit hasWeightCol
+ val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier())
+ assert(ova2.fit(dataset2) !== null)
+ }
+
test("OneVsRest.copy and OneVsRestModel.copy") {
val lr = new LogisticRegression()
.setMaxIter(1)
http://git-wip-us.apache.org/repos/asf/spark/blob/9f670ce5/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 0a30321..7b3bd3b 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1252,7 +1252,7 @@ class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLR
return self._call_java("weights")
-class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
"""
Parameters for OneVsRest and OneVsRestModel.
"""
@@ -1315,10 +1315,10 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
- classifier=None):
+ classifier=None, weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
- classifier=None)
+ classifier=None, weightCol=None)
"""
super(OneVsRest, self).__init__()
kwargs = self._input_kwargs
@@ -1326,9 +1326,11 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
@keyword_only
@since("2.0.0")
- def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+ def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
+ classifier=None, weightCol=None):
"""
- setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+ setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
+ classifier=None, weightCol=None):
Sets params for OneVsRest.
"""
kwargs = self._input_kwargs
@@ -1344,7 +1346,18 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
- multiclassLabeled = dataset.select(labelCol, featuresCol)
+ weightCol = None
+ if (self.isDefined(self.weightCol) and self.getWeightCol()):
+ if isinstance(classifier, HasWeightCol):
+ weightCol = self.getWeightCol()
+ else:
+ warnings.warn("weightCol is ignored, "
+ "as it is not supported by {0} now.".format(classifier))
+
+ if weightCol:
+ multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
+ else:
+ multiclassLabeled = dataset.select(labelCol, featuresCol)
# persist if underlying dataset is not persistent.
handlePersistence = \
@@ -1360,6 +1373,8 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
paramMap = dict([(classifier.labelCol, binaryLabelCol),
(classifier.featuresCol, featuresCol),
(classifier.predictionCol, predictionCol)])
+ if weightCol:
+ paramMap[classifier.weightCol] = weightCol
return classifier.fit(trainingDataset, paramMap)
# TODO: Parallel training for all classes.
http://git-wip-us.apache.org/repos/asf/spark/blob/9f670ce5/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 87f0aff..aea5be7 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1128,6 +1128,20 @@ class OneVsRestTests(SparkSessionTestCase):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])
+ def test_support_for_weightCol(self):
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+ (1.0, Vectors.sparse(2, [], []), 1.0),
+ (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+ ["label", "features", "weight"])
+ # classifier inherits hasWeightCol
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr, weightCol="weight")
+ self.assertIsNotNone(ovr.fit(df))
+ # classifier doesn't inherit hasWeightCol
+ dt = DecisionTreeClassifier()
+ ovr2 = OneVsRest(classifier=dt, weightCol="weight")
+ self.assertIsNotNone(ovr2.fit(df))
+
class HashingTFTest(SparkSessionTestCase):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org