You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@spark.apache.org by me...@apache.org on 2015/10/13 22:24:15 UTC

spark git commit: [SPARK-7402] [ML] JSON SerDe for standard param types

Repository: spark
Updated Branches:
  refs/heads/master c75f058b7 -> 2b574f52d


[SPARK-7402] [ML] JSON SerDe for standard param types

This PR implements the JSON SerDe for the following param types: `Boolean`, `Int`, `Long`, `Float`, `Double`, `String`, `Array[Int]`, `Array[Double]`, and `Array[String]`. The implementations for `Float`, `Double`, and `Array[Double]` are specialized to handle `NaN` and infinities, which are encoded as the JSON strings `"NaN"`, `"Inf"`, and `"-Inf"`. This will be used in pipeline persistence. jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #9090 from mengxr/SPARK-7402.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b574f52
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b574f52
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b574f52

Branch: refs/heads/master
Commit: 2b574f52d7bf51b1fe2a73086a3735b633e9083f
Parents: c75f058
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue Oct 13 13:24:10 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Oct 13 13:24:10 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/param/params.scala      | 169 +++++++++++++++++++
 .../org/apache/spark/ml/param/ParamsSuite.scala | 114 +++++++++++++
 2 files changed, 283 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2b574f52/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index ec98b05..8361406 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -24,6 +24,9 @@ import scala.annotation.varargs
 import scala.collection.mutable
 import scala.collection.JavaConverters._
 
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.ml.util.Identifiable
 
@@ -80,6 +83,30 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
   /** Creates a param pair with the given value (for Scala). */
   def ->(value: T): ParamPair[T] = ParamPair(this, value)
 
+  /**
+   * Encodes a param value into JSON, which can be decoded by [[jsonDecode()]].
+   * Only `String` values are supported here; other value types require an override.
+   */
+  def jsonEncode(value: T): String = value match {
+    case str: String => compact(render(JString(str)))
+    case _ =>
+      val valueClass = value.getClass.getName
+      throw new NotImplementedError("The default jsonEncode only supports string. " +
+        s"${this.getClass.getName} must override jsonEncode for $valueClass.")
+  }
+
+  /**
+   * Decodes a param value from JSON. Only a JSON string is accepted, and it is cast to `T`
+   * without a runtime check (`T` is erased), so this is only meaningful when `T` is `String`.
+   */
+  def jsonDecode(json: String): T = parse(json) match {
+    case JString(str) => str.asInstanceOf[T]
+    case _ =>
+      val paramClass = this.getClass.getName
+      throw new NotImplementedError("The default jsonDecode only supports string. " +
+        s"$paramClass must override jsonDecode to support its value type.")
+  }
+
   override final def toString: String = s"${parent}__$name"
 
   override final def hashCode: Int = toString.##
@@ -198,6 +225,46 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>
 
   /** Creates a param pair with the given value (for Java). */
   override def w(value: Double): ParamPair[Double] = super.w(value)
+
+  /** Encodes the value as JSON, writing `NaN` and infinities as JSON strings. */
+  override def jsonEncode(value: Double): String = {
+    compact(render(DoubleParam.jValueEncode(value)))
+  }
+
+  /** Decodes a value written by [[jsonEncode]]; plain JSON numbers are accepted as well. */
+  override def jsonDecode(json: String): Double = {
+    DoubleParam.jValueDecode(parse(json))
+  }
+}
+
+/** JSON helpers for `Double` values, shared by [[DoubleParam]] and [[DoubleArrayParam]]. */
+private[param] object DoubleParam {
+  /**
+   * Encodes a param value into JValue. JSON has no literals for `NaN` or infinities, so they
+   * are written as the strings "NaN", "Inf", and "-Inf"; other values become JSON numbers.
+   */
+  def jValueEncode(value: Double): JValue = value match {
+    case _ if value.isNaN => JString("NaN")
+    case Double.NegativeInfinity => JString("-Inf")
+    case Double.PositiveInfinity => JString("Inf")
+    case _ => JDouble(value)
+  }
+
+  /**
+   * Decodes a param value from JValue: either a special string written by [[jValueEncode]] or a
+   * JSON number. Integral numbers (parsed by json4s as `JInt`, e.g. `1`) are accepted too, so
+   * hand-written JSON does not need a trailing `.0`.
+   *
+   * @throws IllegalArgumentException if the value is neither a number nor a special string
+   */
+  def jValueDecode(jValue: JValue): Double = jValue match {
+    case JString("NaN") => Double.NaN
+    case JString("-Inf") => Double.NegativeInfinity
+    case JString("Inf") => Double.PositiveInfinity
+    case JDouble(x) => x
+    case JInt(x) => x.toDouble
+    case _ => throw new IllegalArgumentException(s"Cannot decode $jValue to Double.")
+  }
 }
 
 /**
@@ -218,6 +285,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea
 
   /** Creates a param pair with the given value (for Java). */
   override def w(value: Int): ParamPair[Int] = super.w(value)
+
+  /** Encodes the value as a JSON integer. */
+  override def jsonEncode(value: Int): String = compact(render(JInt(BigInt(value))))
+
+  /** Decodes the value via json4s extraction (`extract[Int]`) with the default formats. */
+  override def jsonDecode(json: String): Int = {
+    implicit val formats: Formats = DefaultFormats
+    parse(json).extract[Int]
+  }
 }
 
 /**
@@ -238,6 +314,47 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo
 
   /** Creates a param pair with the given value (for Java). */
   override def w(value: Float): ParamPair[Float] = super.w(value)
+
+  /** Encodes the value as JSON, writing `NaN` and infinities as JSON strings. */
+  override def jsonEncode(value: Float): String = {
+    compact(render(FloatParam.jValueEncode(value)))
+  }
+
+  /** Decodes a value written by [[jsonEncode]]; plain JSON numbers are accepted as well. */
+  override def jsonDecode(json: String): Float = {
+    FloatParam.jValueDecode(parse(json))
+  }
+}
+
+/** JSON helpers for `Float` values, using the same encoding as the `DoubleParam` companion. */
+private[param] object FloatParam {
+  /**
+   * Encodes a param value into JValue. JSON has no literals for `NaN` or infinities, so they
+   * are written as the strings "NaN", "Inf", and "-Inf"; other values are widened (exactly) to
+   * `Double` and become JSON numbers.
+   */
+  def jValueEncode(value: Float): JValue = value match {
+    case _ if value.isNaN => JString("NaN")
+    case Float.NegativeInfinity => JString("-Inf")
+    case Float.PositiveInfinity => JString("Inf")
+    case _ => JDouble(value)
+  }
+
+  /**
+   * Decodes a param value from JValue: either a special string written by [[jValueEncode]] or a
+   * JSON number, narrowed to `Float`. Integral numbers (parsed by json4s as `JInt`, e.g. `1`)
+   * are accepted too, so hand-written JSON does not need a trailing `.0`.
+   *
+   * @throws IllegalArgumentException if the value is neither a number nor a special string
+   */
+  def jValueDecode(jValue: JValue): Float = jValue match {
+    case JString("NaN") => Float.NaN
+    case JString("-Inf") => Float.NegativeInfinity
+    case JString("Inf") => Float.PositiveInfinity
+    case JDouble(x) => x.toFloat
+    case JInt(x) => x.toFloat
+    case _ => throw new IllegalArgumentException(s"Cannot decode $jValue to Float.")
+  }
 }
 
 /**
@@ -258,6 +375,15 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool
 
   /** Creates a param pair with the given value (for Java). */
   override def w(value: Long): ParamPair[Long] = super.w(value)
+
+  /** Encodes the value as a JSON integer. */
+  override def jsonEncode(value: Long): String = compact(render(JInt(BigInt(value))))
+
+  /** Decodes the value via json4s extraction (`extract[Long]`) with the default formats. */
+  override def jsonDecode(json: String): Long = {
+    implicit val formats: Formats = DefaultFormats
+    parse(json).extract[Long]
+  }
 }
 
 /**
@@ -272,6 +398,15 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV
 
   /** Creates a param pair with the given value (for Java). */
   override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
+
+  /** Encodes the value as a JSON boolean literal. */
+  override def jsonEncode(value: Boolean): String = compact(render(JBool(value)))
+
+  /** Decodes the value via json4s extraction (`extract[Boolean]`) with the default formats. */
+  override def jsonDecode(json: String): Boolean = {
+    implicit val formats: Formats = DefaultFormats
+    parse(json).extract[Boolean]
+  }
 }
 
 /**
@@ -287,6 +422,16 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
 
   /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
   def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
+
+  /** Encodes the values as a JSON array of strings. */
+  override def jsonEncode(value: Array[String]): String =
+    compact(render(JArray(value.toList.map(JString(_)))))
+
+  /** Decodes a JSON array of strings via json4s extraction with the default formats. */
+  override def jsonDecode(json: String): Array[String] = {
+    implicit val formats: Formats = DefaultFormats
+    parse(json).extract[Seq[String]].toArray
+  }
 }
 
 /**
@@ -303,6 +448,20 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
   /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
   def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
     w(value.asScala.map(_.asInstanceOf[Double]).toArray)
+
+  /** Encodes the values as a JSON array, using `DoubleParam`'s per-element encoding. */
+  override def jsonEncode(value: Array[Double]): String =
+    compact(render(JArray(value.toList.map(DoubleParam.jValueEncode))))
+
+  /**
+   * Decodes a JSON array whose elements use `DoubleParam`'s per-element encoding.
+   *
+   * @throws IllegalArgumentException if the JSON is not an array or an element is undecodable
+   */
+  override def jsonDecode(json: String): Array[Double] = parse(json) match {
+    case JArray(elements) => elements.iterator.map(DoubleParam.jValueDecode).toArray
+    case _ => throw new IllegalArgumentException(s"Cannot decode $json to Array[Double].")
+  }
 }
 
 /**
@@ -319,6 +478,16 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In
   /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
   def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] =
     w(value.asScala.map(_.asInstanceOf[Int]).toArray)
+
+  /** Encodes the values as a JSON array of integers. */
+  override def jsonEncode(value: Array[Int]): String =
+    compact(render(JArray(value.toList.map(v => JInt(BigInt(v))))))
+
+  /** Decodes a JSON array of integers via json4s extraction with the default formats. */
+  override def jsonDecode(json: String): Array[Int] = {
+    implicit val formats: Formats = DefaultFormats
+    parse(json).extract[Seq[Int]].toArray
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/2b574f52/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index a2ea279..eeb03db 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -21,6 +21,74 @@ import org.apache.spark.SparkFunSuite
 
 class ParamsSuite extends SparkFunSuite {
 
+  test("json encode/decode") {
+    val dummy = new Params {
+      override def copy(extra: ParamMap): Params = defaultCopy(extra)
+
+      override val uid: String = "dummy"
+    }
+
+    // Asserts that every value survives an encode/decode round trip. `===` compares arrays
+    // element-wise, so this also covers the array-valued params.
+    def checkRoundTrip[T](param: Param[T])(values: T*): Unit = {
+      values.foreach { value =>
+        assert(param.jsonDecode(param.jsonEncode(value)) === value)
+      }
+    }
+
+    // NaN never equals itself, so floating-point round trips need a NaN-aware comparison.
+    def checkSameDouble(actual: Double, expected: Double): Unit = {
+      if (expected.isNaN) {
+        assert(actual.isNaN)
+      } else {
+        assert(actual === expected)
+      }
+    }
+
+    checkRoundTrip(new BooleanParam(dummy, "name", "doc"))(true, false)
+
+    checkRoundTrip(new IntParam(dummy, "name", "doc"))(Int.MinValue, -1, 0, 1, Int.MaxValue)
+
+    checkRoundTrip(new LongParam(dummy, "name", "doc"))(
+      Long.MinValue, -1L, 0L, 1L, Long.MaxValue)
+
+    val floatParam = new FloatParam(dummy, "name", "doc")
+    val floatValues = Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f,
+      Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)
+    for (value <- floatValues) {
+      checkSameDouble(floatParam.jsonDecode(floatParam.jsonEncode(value)), value)
+    }
+
+    val doubleParam = new DoubleParam(dummy, "name", "doc")
+    val doubleValues = Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0,
+      Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)
+    for (value <- doubleValues) {
+      checkSameDouble(doubleParam.jsonDecode(doubleParam.jsonEncode(value)), value)
+    }
+
+    // Currently we do not support null.
+    checkRoundTrip(new Param[String](dummy, "name", "doc"))(
+      "", "1", "abc", "quote\"", "newline\n")
+
+    checkRoundTrip(new IntArrayParam(dummy, "name", "doc"))(
+      Array(), Array(1), Array(Int.MinValue, 0, Int.MaxValue))
+
+    val doubleArrayParam = new DoubleArrayParam(dummy, "name", "doc")
+    val doubleArrays: Seq[Array[Double]] = Seq(
+      Array(),
+      Array(1.0),
+      Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+        Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
+    for (expected <- doubleArrays) {
+      val decoded = doubleArrayParam.jsonDecode(doubleArrayParam.jsonEncode(expected))
+      assert(decoded.length === expected.length)
+      decoded.zip(expected).foreach { case (actual, exp) => checkSameDouble(actual, exp) }
+    }
+
+    checkRoundTrip(new StringArrayParam(dummy, "name", "doc"))(
+      Array(), Array(""), Array("", "1", "abc", "quote\"", "newline\n"))
+  }
+
   test("param") {
     val solver = new TestParams()
     val uid = solver.uid


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org