You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/02/15 17:54:43 UTC
spark git commit: [SPARK-23377][ML] Fixes Bucketizer with multiple
columns persistence bug
Repository: spark
Updated Branches:
refs/heads/master 6968c3cfd -> db45daab9
[SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug
## What changes were proposed in this pull request?
#### Problem:
Since 2.3, `Bucketizer` supports multiple input/output columns. We will check if exclusive params are set during transformation. E.g., if `inputCols` and `outputCol` are both set, an error will be thrown.
However, when we write a `Bucketizer`, it looks like the default params and user-supplied params are merged during writing. All saved params are then loaded back and set on the created model instance. So the default `outputCol` param in the `HasOutputCol` trait will be set in `paramMap` and become a user-supplied param. That makes the check of exclusive params fail.
#### Fix:
This changes the saving logic of `Bucketizer` to handle this case. This is a quick fix to make it in time for the 2.3 release. We should consider modifying the persistence mechanism later.
Please see the discussion in the JIRA.
Note: The multi-column `QuantileDiscretizer` also has the same issue.
## How was this patch tested?
Modified tests.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #20594 from viirya/SPARK-23377-2.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/db45daab
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/db45daab
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/db45daab
Branch: refs/heads/master
Commit: db45daab90ede4c03c1abc9096f4eac584e9db17
Parents: 6968c3c
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Feb 15 09:54:39 2018 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Feb 15 09:54:39 2018 -0800
----------------------------------------------------------------------
.../apache/spark/ml/feature/Bucketizer.scala | 28 ++++++++++++++++++++
.../spark/ml/feature/QuantileDiscretizer.scala | 28 ++++++++++++++++++++
.../spark/ml/feature/BucketizerSuite.scala | 12 +++++++--
.../ml/feature/QuantileDiscretizerSuite.scala | 14 ++++++++--
4 files changed, 78 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/db45daab/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index c13bf47..f49c410 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -19,6 +19,10 @@ package org.apache.spark.ml.feature
import java.{util => ju}
+import org.json4s.JsonDSL._
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Model
@@ -213,6 +217,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}
+
+ override def write: MLWriter = new Bucketizer.BucketizerWriter(this)
}
@Since("1.6.0")
@@ -290,6 +296,28 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
}
}
+
+ private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // SPARK-23377: The default params will be saved and loaded as user-supplied params.
+ // Once `inputCols` is set, the default value of the `outputCol` param causes an error
+ // when checking exclusive params. As a temporary fix, we skip the `outputCol` param
+ // if `inputCols` is set when saving the metadata.
+ // TODO: If we modify the persistence mechanism later to better handle default params,
+ // we can get rid of this.
+ var paramWithoutOutputCol: Option[JValue] = None
+ if (instance.isSet(instance.inputCols)) {
+ val params = instance.extractParamMap().toSeq
+ val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) =>
+ p.name -> parse(p.jsonEncode(v))
+ }.toList
+ paramWithoutOutputCol = Some(render(jsonParams))
+ }
+ DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol)
+ }
+ }
+
@Since("1.6.0")
override def load(path: String): Bucketizer = super.load(path)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/db45daab/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 1ec5f8c..3b4c254 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -17,6 +17,10 @@
package org.apache.spark.ml.feature
+import org.json4s.JsonDSL._
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
@@ -249,11 +253,35 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
@Since("1.6.0")
override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
+
+ override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this)
}
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
+ private[QuantileDiscretizer]
+ class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // SPARK-23377: The default params will be saved and loaded as user-supplied params.
+ // Once `inputCols` is set, the default value of the `outputCol` param causes an error
+ // when checking exclusive params. As a temporary fix, we skip the `outputCol` param
+ // if `inputCols` is set when saving the metadata.
+ // TODO: If we modify the persistence mechanism later to better handle default params,
+ // we can get rid of this.
+ var paramWithoutOutputCol: Option[JValue] = None
+ if (instance.isSet(instance.inputCols)) {
+ val params = instance.extractParamMap().toSeq
+ val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) =>
+ p.name -> parse(p.jsonEncode(v))
+ }.toList
+ paramWithoutOutputCol = Some(render(jsonParams))
+ }
+ DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol)
+ }
+ }
+
@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/db45daab/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 7403680..41cf72f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -172,7 +172,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setSplits(Array(0.1, 0.8, 0.9))
- testDefaultReadWrite(t)
+
+ val bucketizer = testDefaultReadWrite(t)
+ val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
+ bucketizer.transform(data)
}
test("Bucket numeric features") {
@@ -327,7 +330,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
.setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
- testDefaultReadWrite(t)
+
+ val bucketizer = testDefaultReadWrite(t)
+ val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
+ bucketizer.transform(data)
+ assert(t.hasDefault(t.outputCol))
+ assert(bucketizer.hasDefault(bucketizer.outputCol))
}
test("Bucketizer in a pipeline") {
http://git-wip-us.apache.org/repos/asf/spark/blob/db45daab/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index e9a75e9..6c36379 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.functions.udf
class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+ import testImplicits._
+
test("Test observed number of buckets and their sizes match expected values") {
val spark = this.spark
import spark.implicits._
@@ -132,7 +134,10 @@ class QuantileDiscretizerSuite
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setNumBuckets(6)
- testDefaultReadWrite(t)
+
+ val readDiscretizer = testDefaultReadWrite(t)
+ val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol")
+ readDiscretizer.fit(data)
}
test("Verify resulting model has parent") {
@@ -379,7 +384,12 @@ class QuantileDiscretizerSuite
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("result1", "result2"))
.setNumBucketsArray(Array(5, 10))
- testDefaultReadWrite(discretizer)
+
+ val readDiscretizer = testDefaultReadWrite(discretizer)
+ val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2")
+ readDiscretizer.fit(data)
+ assert(discretizer.hasDefault(discretizer.outputCol))
+ assert(readDiscretizer.hasDefault(readDiscretizer.outputCol))
}
test("Multiple Columns: Both inputCol and inputCols are set") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org