You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/07/16 05:33:10 UTC

spark git commit: [SPARK-8774] [ML] Add R model formula with basic support as a transformer

Repository: spark
Updated Branches:
  refs/heads/master b0645195d -> 6960a7938


[SPARK-8774] [ML] Add R model formula with basic support as a transformer

This implements minimal R formula support as a feature transformer. Numeric and boolean labels are supported (string labels are not supported yet), and features must be numeric for now.

cc mengxr

Author: Eric Liang <ek...@databricks.com>

Closes #7381 from ericl/spark-8774-1 and squashes the following commits:

d1959d2 [Eric Liang] clarify comment
2db68aa [Eric Liang] second round of comments
dc3c943 [Eric Liang] address comments
5765ec6 [Eric Liang] fix style checks
1f361b0 [Eric Liang] doc
fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6960a793
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6960a793
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6960a793

Branch: refs/heads/master
Commit: 6960a7938c61cc07f181ca85e0d8152ceeb453d9
Parents: b064519
Author: Eric Liang <ek...@databricks.com>
Authored: Wed Jul 15 20:33:06 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Jul 15 20:33:06 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/RFormula.scala  | 151 +++++++++++++++++++
 .../spark/ml/feature/VectorAssembler.scala      |   2 +-
 .../spark/ml/feature/RFormulaParserSuite.scala  |  34 +++++
 .../apache/spark/ml/feature/RFormulaSuite.scala |  93 ++++++++++++
 4 files changed, 279 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6960a793/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
new file mode 100644
index 0000000..d9a36bd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Implements the transforms required for fitting a dataset against an R model formula. Currently
+ * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
+ * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ */
+@Experimental
+class RFormula(override val uid: String)
+  extends Transformer with HasFeaturesCol with HasLabelCol {
+
+  def this() = this(Identifiable.randomUID("rFormula"))
+
+  /**
+   * R formula parameter. The formula is provided in string form.
+   * @group param
+   */
+  val formula: Param[String] = new Param(this, "formula", "R model formula")
+
+  // Derived from the param rather than cached, so instances created by copy() stay usable.
+  private def parsedFormula: Option[ParsedRFormula] = get(formula).map(RFormulaParser.parse(_))
+
+  /**
+   * Sets the formula to use for this transformer. Must be called before use.
+   * @group setParam
+   * @param value an R formula in string form (e.g. "y ~ x + z")
+   */
+  def setFormula(value: String): this.type = {
+    RFormulaParser.parse(value) // parse eagerly so malformed formulas are rejected here
+    set(formula, value)
+  }
+
+  /** @group getParam */
+  def getFormula: String = $(formula)
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  override def transformSchema(schema: StructType): StructType = {
+    checkCanTransform(schema)
+    val withFeatures = transformFeatures.transformSchema(schema)
+    if (hasLabelCol(schema)) {
+      withFeatures
+    } else {
+      val labelField = schema(parsedFormula.get.label)
+      // Casting a numeric/boolean column to double keeps the source column's nullability.
+      val nullable = labelField.dataType match {
+        case _: NumericType | BooleanType => labelField.nullable
+        case _ => true
+      }
+      StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
+    }
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    checkCanTransform(dataset.schema)
+    transformLabel(transformFeatures.transform(dataset))
+  }
+
+  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
+
+  override def toString: String = s"RFormula(${get(formula)})"
+
+  private def transformLabel(dataset: DataFrame): DataFrame = {
+    if (hasLabelCol(dataset.schema)) {
+      dataset
+    } else {
+      val labelName = parsedFormula.get.label
+      dataset.schema(labelName).dataType match {
+        case _: NumericType | BooleanType =>
+          dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
+        // TODO(ekl) add support for string-type labels
+        case other =>
+          throw new IllegalArgumentException("Unsupported type for label: " + other)
+      }
+    }
+  }
+
+  private def transformFeatures: Transformer = {
+    // TODO(ekl) add support for non-numeric features and feature interactions
+    new VectorAssembler(uid)
+      .setInputCols(parsedFormula.get.terms.toArray)
+      .setOutputCol($(featuresCol))
+  }
+
+  private def checkCanTransform(schema: StructType): Unit = {
+    require(parsedFormula.isDefined, "Must call setFormula() first.")
+    val columnNames = schema.fieldNames
+    require(!columnNames.contains($(featuresCol)), "Features column already exists.")
+    require(
+      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
+      "Label column already exists and is not of type DoubleType.")
+  }
+
+  private def hasLabelCol(schema: StructType): Boolean = schema.fieldNames.contains($(labelCol))
+}
+
+/**
+ * Parsed R formula: `label` is the name left of '~' and `terms` are the names joined by '+'.
+ */
+private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
+
+/**
+ * Minimal R formula parser. Only '~' (separating the label from the terms) and '+' (adding
+ * terms) are recognized; any other input fails with an IllegalArgumentException.
+ */
+private[ml] object RFormulaParser extends RegexParsers {
+  def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
+
+  def expr: Parser[List[String]] = rep1sep(term, "+")
+
+  def formula: Parser[ParsedRFormula] =
+    (term <~ "~") ~ expr ^^ { case label ~ terms => ParsedRFormula(label, terms) }
+
+  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
+    case Success(parsed, _) => parsed
+    case _: NoSuccess => throw new IllegalArgumentException(
+      "Could not parse formula: " + value)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6960a793/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 9f83c2e..086917f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
     if (schema.fieldNames.contains(outputColName)) {
       throw new IllegalArgumentException(s"Output column $outputColName already exists.")
     }
-    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
+    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
   }
 
   override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)

http://git-wip-us.apache.org/repos/asf/spark/blob/6960a793/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
new file mode 100644
index 0000000..c8d065f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+
+class RFormulaParserSuite extends SparkFunSuite {
+  /** Asserts that `formula` parses into exactly the given label and ordered terms. */
+  private def checkParse(formula: String, label: String, terms: Seq[String]): Unit =
+    assert(RFormulaParser.parse(formula) === ParsedRFormula(label, terms))
+
+  test("parse simple formulas") {
+    Seq(
+      ("y ~ x", "y", Seq("x")),
+      ("y ~   ._foo  ", "y", Seq("._foo")),
+      ("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
+    ).foreach { case (formula, label, terms) => checkParse(formula, label, terms) }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6960a793/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
new file mode 100644
index 0000000..fa8611b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Tests for [[RFormula]]: numeric transformation plus validation of clashes with existing
+ * features and label columns.
+ */
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+  test("params") {
+    ParamsSuite.checkParams(new RFormula())
+  }
+
+  test("transform numeric data") {
+    val rf = new RFormula().setFormula("id ~ v1 + v2")
+    val input = sqlContext.createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0)))
+      .toDF("id", "v1", "v2")
+    val output = rf.transform(input)
+    val outputSchema = rf.transformSchema(input.schema)
+    // "id" becomes the double-typed label; v1 and v2 are assembled into the features vector.
+    val expectedRows = Seq(
+      (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
+      (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
+    val expected = sqlContext.createDataFrame(expectedRows)
+      .toDF("id", "v1", "v2", "features", "label")
+    // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
+    assert(output.schema.toString === outputSchema.toString)
+    assert(outputSchema === expected.schema)
+    assert(output.collect().toSeq === expected.collect().toSeq)
+  }
+
+  test("features column already exists") {
+    val rf = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
+    val input = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+    // Both the schema-level check and the actual transform must reject the clash.
+    intercept[IllegalArgumentException](rf.transformSchema(input.schema))
+    intercept[IllegalArgumentException](rf.transform(input))
+  }
+
+  test("label column already exists") {
+    val rf = new RFormula().setFormula("y ~ x").setLabelCol("y")
+    val input = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+    val outputSchema = rf.transformSchema(input.schema)
+    // The existing double-typed "y" is reused as the label, so only "features" is added.
+    assert(outputSchema.length === 3)
+    assert(outputSchema.toString === rf.transform(input).schema.toString)
+  }
+
+  test("label column already exists but is not double type") {
+    val rf = new RFormula().setFormula("y ~ x").setLabelCol("y")
+    val input = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+    // An integer "y" cannot serve as the label column, so both paths must throw.
+    intercept[IllegalArgumentException](rf.transformSchema(input.schema))
+    intercept[IllegalArgumentException](rf.transform(input))
+  }
+
+// TODO(ekl) enable after we implement string label support
+//  test("transform string label") {
+//    val formula = new RFormula().setFormula("name ~ id")
+//    val original = sqlContext.createDataFrame(
+//      Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
+//    val result = formula.transform(original)
+//    val resultSchema = formula.transformSchema(original.schema)
+//    val expected = sqlContext.createDataFrame(
+//      Seq(
+//        (1, "foo", Vectors.dense(Array(1.0)), 1.0),
+//        (2, "bar", Vectors.dense(Array(2.0)), 0.0),
+//        (3, "bar", Vectors.dense(Array(3.0)), 0.0))
+//      ).toDF("id", "name", "features", "label")
+//    assert(result.schema.toString == resultSchema.toString)
+//    assert(result.collect().toSeq == expected.collect().toSeq)
+//  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org