You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ml...@apache.org on 2018/01/26 10:23:20 UTC
spark git commit: [SPARK-22799][ML] Bucketizer should throw exception
if single- and multi-column params are both set
Repository: spark
Updated Branches:
refs/heads/master d1721816d -> cd3956df0
[SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set
## What changes were proposed in this pull request?
Currently, transformers that support both single- and multi-column params behave inconsistently when both sets of params are set: in some cases an exception is thrown, in others only a warning is logged. In the discussion at https://issues.apache.org/jira/browse/SPARK-8418?focusedCommentId=16275049&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-16275049, the decision was to throw an exception.
The PR throws an exception in `Bucketizer`, instead of logging a warning.
## How was this patch tested?
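Modified the existing unit tests in `BucketizerSuite` and added a shared `testExclusiveParams` helper to `ParamsSuite`.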
Author: Marco Gaido <ma...@gmail.com>
Author: Joseph K. Bradley <jo...@databricks.com>
Closes #19993 from mgaido91/SPARK-22799.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cd3956df
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cd3956df
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cd3956df
Branch: refs/heads/master
Commit: cd3956df0f96dd416b6161bf7ce2962e06d0a62e
Parents: d172181
Author: Marco Gaido <ma...@gmail.com>
Authored: Fri Jan 26 12:23:14 2018 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Fri Jan 26 12:23:14 2018 +0200
----------------------------------------------------------------------
.../apache/spark/ml/feature/Bucketizer.scala | 44 ++++++-------
.../org/apache/spark/ml/param/params.scala | 69 ++++++++++++++++++++
.../spark/ml/feature/BucketizerSuite.scala | 41 ++++++------
.../org/apache/spark/ml/param/ParamsSuite.scala | 22 +++++++
4 files changed, 131 insertions(+), 45 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/cd3956df/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 8299a3e..c13bf47 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -32,11 +32,13 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
/**
- * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ *
+ * Since 2.3.0,
* `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
- * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
- * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
- * only used for single column usage, and `splitsArray` is for multiple columns.
+ * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
+ * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
+ * columns.
*/
@Since("1.4.0")
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@@ -134,28 +136,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
- /**
- * Determines whether this `Bucketizer` is going to map multiple columns. If and only if
- * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified
- * by `inputCol`. A warning will be printed if both are set.
- */
- private[feature] def isBucketizeMultipleColumns(): Boolean = {
- if (isSet(inputCols) && isSet(inputCol)) {
- logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
- "`Bucketizer` only map one column specified by `inputCol`")
- false
- } else if (isSet(inputCols)) {
- true
- } else {
- false
- }
- }
-
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema)
- val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
+ val (inputColumns, outputColumns) = if (isSet(inputCols)) {
($(inputCols).toSeq, $(outputCols).toSeq)
} else {
(Seq($(inputCol)), Seq($(outputCol)))
@@ -170,7 +155,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
}
}
- val seqOfSplits = if (isBucketizeMultipleColumns()) {
+ val seqOfSplits = if (isSet(inputCols)) {
$(splitsArray).toSeq
} else {
Seq($(splits))
@@ -201,9 +186,18 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- if (isBucketizeMultipleColumns()) {
+ ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
+ Seq(outputCols, splitsArray))
+
+ if (isSet(inputCols)) {
+ require(getInputCols.length == getOutputCols.length &&
+ getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
+ s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
+ s"equal lengths, but they have different lengths: " +
+ s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")
+
var transformedSchema = schema
- $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
+ $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
SchemaUtils.checkNumericType(transformedSchema, inputCol)
transformedSchema = SchemaUtils.appendColumn(transformedSchema,
prepOutputField($(splitsArray)(idx), outputCol))
http://git-wip-us.apache.org/repos/asf/spark/blob/cd3956df/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 1b4b401..9a83a58 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -249,6 +249,75 @@ object ParamValidators {
def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
value.length > lowerBound
}
+
+ /**
+ * Utility for Param validity checks for Transformers which have both single- and multi-column
+ * support. This utility assumes that `inputCol` indicates single-column usage and
+ * that `inputCols` indicates multi-column usage.
+ *
+ * This checks to ensure that exactly one set of Params has been set, and it
+ * raises an `IllegalArgumentException` if not.
+ *
+ * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been
+ * set. This does not need to include `inputCol`.
+ * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been
+ * set. This does not need to include `inputCols`.
+ */
+ def checkSingleVsMultiColumnParams(
+ model: Params,
+ singleColumnParams: Seq[Param[_]],
+ multiColumnParams: Seq[Param[_]]): Unit = {
+ val name = s"${model.getClass.getSimpleName} $model"
+
+ def checkExclusiveParams(
+ isSingleCol: Boolean,
+ requiredParams: Seq[Param[_]],
+ excludedParams: Seq[Param[_]]): Unit = {
+ val badParamsMsgBuilder = new mutable.StringBuilder()
+
+ val mustUnsetParams = excludedParams.filter(p => model.isSet(p))
+ .map(_.name).mkString(", ")
+ if (mustUnsetParams.nonEmpty) {
+ badParamsMsgBuilder ++=
+ s"The following Params are not applicable and should not be set: $mustUnsetParams."
+ }
+
+ val mustSetParams = requiredParams.filter(p => !model.isDefined(p))
+ .map(_.name).mkString(", ")
+ if (mustSetParams.nonEmpty) {
+ badParamsMsgBuilder ++=
+ s"The following Params must be defined but are not set: $mustSetParams."
+ }
+
+ val badParamsMsg = badParamsMsgBuilder.toString()
+
+ if (badParamsMsg.nonEmpty) {
+ val errPrefix = if (isSingleCol) {
+ s"$name has the inputCol Param set for single-column transform."
+ } else {
+ s"$name has the inputCols Param set for multi-column transform."
+ }
+ throw new IllegalArgumentException(s"$errPrefix $badParamsMsg")
+ }
+ }
+
+ val inputCol = model.getParam("inputCol")
+ val inputCols = model.getParam("inputCols")
+
+ if (model.isSet(inputCol)) {
+ require(!model.isSet(inputCols), s"$name requires " +
+ s"exactly one of inputCol, inputCols Params to be set, but both are set.")
+
+ checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams,
+ excludedParams = multiColumnParams)
+ } else if (model.isSet(inputCols)) {
+ checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams,
+ excludedParams = singleColumnParams)
+ } else {
+ throw new IllegalArgumentException(s"$name requires " +
+ s"exactly one of inputCol, inputCols Params to be set, but neither is set.")
+ }
+ }
}
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
http://git-wip-us.apache.org/repos/asf/spark/blob/cd3956df/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index d9c97ae..7403680 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)
- assert(bucketizer1.isBucketizeMultipleColumns())
-
bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame),
Seq("result1", "result2"),
@@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result"))
.setSplitsArray(Array(splits(0)))
- assert(bucketizer2.isBucketizeMultipleColumns())
-
withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
intercept[SparkException] {
bucketizer2.transform(badDF1).collect()
@@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)
- assert(bucketizer.isBucketizeMultipleColumns())
-
BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
Seq("result1", "result2"),
Seq("expected1", "expected2"))
@@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)
- assert(bucketizer.isBucketizeMultipleColumns())
-
bucketizer.setHandleInvalid("keep")
BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
Seq("result1", "result2"),
@@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
.setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
- assert(t.isBucketizeMultipleColumns())
testDefaultReadWrite(t)
}
@@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
- assert(bucket.isBucketizeMultipleColumns())
-
val pl = new Pipeline()
.setStages(Array(bucket))
.fit(df)
@@ -401,15 +390,27 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
}
- test("Both inputCol and inputCols are set") {
- val bucket = new Bucketizer()
- .setInputCol("feature1")
- .setOutputCol("result")
- .setSplits(Array(-0.5, 0.0, 0.5))
- .setInputCols(Array("feature1", "feature2"))
-
- // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`.
- assert(bucket.isBucketizeMultipleColumns() == false)
+ test("assert exception is thrown if both multi-column and single-column params are set") {
+ val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
+ ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+ ("inputCols", Array("feature1", "feature2")))
+ ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+ ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)),
+ ("outputCols", Array("result1", "result2")))
+ ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+ ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)),
+ ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))))
+
+ // this should fail because at least one of inputCol and inputCols must be set
+ ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "feature1"),
+ ("splits", Array(-0.5, 0.0, 0.5)))
+
+ // the following should fail because not all the params are set
+ ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+ ("outputCol", "result1"))
+ ParamsSuite.testExclusiveParams(new Bucketizer, df,
+ ("inputCols", Array("feature1", "feature2")),
+ ("outputCols", Array("result1", "result2")))
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/cd3956df/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 85198ad..36e0609 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -20,8 +20,10 @@ package org.apache.spark.ml.param
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.MyParams
+import org.apache.spark.sql.Dataset
class ParamsSuite extends SparkFunSuite {
@@ -430,4 +432,24 @@ object ParamsSuite extends SparkFunSuite {
require(copyReturnType === obj.getClass,
s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
}
+
+ /**
+ * Checks that the class throws an exception in case multiple exclusive params are set.
+ * The params to be checked are passed as arguments with their value.
+ */
+ def testExclusiveParams(
+ model: Params,
+ dataset: Dataset[_],
+ paramsAndValues: (String, Any)*): Unit = {
+ val m = model.copy(ParamMap.empty)
+ paramsAndValues.foreach { case (paramName, paramValue) =>
+ m.set(m.getParam(paramName), paramValue)
+ }
+ intercept[IllegalArgumentException] {
+ m match {
+ case t: Transformer => t.transform(dataset)
+ case e: Estimator[_] => e.fit(dataset)
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org