You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sa...@apache.org on 2015/12/28 21:33:27 UTC

spark git commit: [SPARK-12424][ML] The implementation of ParamMap#filter is wrong.

Repository: spark
Updated Branches:
  refs/heads/master e01c6c866 -> 07165ca06


[SPARK-12424][ML] The implementation of ParamMap#filter is wrong.

ParamMap#filter uses `mutable.Map#filterKeys`. The return type of `filterKeys` is `collection.Map`, not `mutable.Map`, but the result is cast to `mutable.Map` using `asInstanceOf`, so we get a `ClassCastException`.
Also, the object returned by `Map#filterKeys` is not Serializable. This is a known Scala issue (https://issues.scala-lang.org/browse/SI-6654).

Author: Kousuke Saruta <sa...@oss.nttdata.co.jp>

Closes #10381 from sarutak/SPARK-12424.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/07165ca0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/07165ca0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/07165ca0

Branch: refs/heads/master
Commit: 07165ca06fe0866677525f85fec25e4dbd336674
Parents: e01c6c8
Author: Kousuke Saruta <sa...@oss.nttdata.co.jp>
Authored: Tue Dec 29 05:33:19 2015 +0900
Committer: Kousuke Saruta <sa...@oss.nttdata.co.jp>
Committed: Tue Dec 29 05:33:19 2015 +0900

----------------------------------------------------------------------
 .../org/apache/spark/ml/param/params.scala      |  8 ++++--
 .../org/apache/spark/ml/param/ParamsSuite.scala | 28 ++++++++++++++++++++
 2 files changed, 34 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/07165ca0/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index ee7e89e..c054669 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -859,8 +859,12 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
    * Filters this param map for the given parent.
    */
   def filter(parent: Params): ParamMap = {
-    val filtered = map.filterKeys(_.parent == parent)
-    new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
+    // Don't use filterKeys: mutable.Map#filterKeys returns a collection.Map,
+    // not a mutable.Map, so casting its result throws ClassCastException.
+    // Avoiding filterKeys also sidesteps SI-6654 (its result is not Serializable).
+    // Keys are matched against the parent's uid, which is what Param#parent holds.
+    val filtered = map.filter { case (k, _) => k.parent == parent.uid }
+    new ParamMap(filtered)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/07165ca0/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index a1878be..7488685 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.ml.param
 
+import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.MyParams
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 class ParamsSuite extends SparkFunSuite {
@@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite {
     val t3 = t.copy(ParamMap(t.maxIter -> 20))
     assert(t3.isSet(t3.maxIter))
   }
+
+  test("Filtering ParamMap") {
+    val params1 = new MyParams("my_params1")
+    val params2 = new MyParams("my_params2")
+    val paramMap = ParamMap(
+      params1.intParam -> 1,
+      params2.intParam -> 1,
+      params1.doubleParam -> 0.2,
+      params2.doubleParam -> 0.2)
+    val filteredParamMap = paramMap.filter(params1)
+
+    assert(filteredParamMap.size === 2)
+    filteredParamMap.toSeq.foreach {
+      case ParamPair(p, _) =>
+        assert(p.parent === params1.uid)
+    }
+
+    // The previous implementation of ParamMap#filter used
+    // mutable.Map#filterKeys internally, whose result is not
+    // serializable (see SI-6654). mutable.Map#filter is now used
+    // instead and its result is serializable, so check that
+    // serializing the filtered ParamMap does not throw.
+    val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
+    objOut.writeObject(filteredParamMap)
+  }
 }
 
 object ParamsSuite extends SparkFunSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org