You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/01/18 21:51:02 UTC

spark git commit: [SPARK-12346][ML] Missing attribute names in GLM for vector-type features

Repository: spark
Updated Branches:
  refs/heads/master 44fcf992a -> 5e492e9d5


[SPARK-12346][ML] Missing attribute names in GLM for vector-type features

Currently `summary()` fails on a GLM model fitted over a vector-type feature column whose ML attributes are missing names, because the attributes of the assembled output features then have no names either. We can avoid this by having `VectorAssembler` generate names of the form `<column>_<index>` (e.g. `vec_0`, `vec_1`) whenever the input attributes are missing names.

cc mengxr

Author: Eric Liang <ek...@databricks.com>

Closes #10323 from ericl/spark-12346.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5e492e9d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5e492e9d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5e492e9d

Branch: refs/heads/master
Commit: 5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c
Parents: 44fcf99
Author: Eric Liang <ek...@databricks.com>
Authored: Mon Jan 18 12:50:58 2016 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Jan 18 12:50:58 2016 -0800

----------------------------------------------------------------------
 .../spark/ml/feature/VectorAssembler.scala      |  6 ++--
 .../apache/spark/ml/feature/RFormulaSuite.scala | 38 ++++++++++++++++++++
 .../spark/ml/feature/VectorAssemblerSuite.scala |  4 +--
 3 files changed, 43 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5e492e9d/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 716bc63..7ff5ad1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -70,19 +70,19 @@ class VectorAssembler(override val uid: String)
           val group = AttributeGroup.fromStructField(field)
           if (group.attributes.isDefined) {
             // If attributes are defined, copy them with updated names.
-            group.attributes.get.map { attr =>
+            group.attributes.get.zipWithIndex.map { case (attr, i) =>
               if (attr.name.isDefined) {
                 // TODO: Define a rigorous naming scheme.
                 attr.withName(c + "_" + attr.name.get)
               } else {
-                attr
+                attr.withName(c + "_" + i)
               }
             }
           } else {
             // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
             // from metadata, check the first row.
             val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
-            Array.fill(numAttrs)(NumericAttribute.defaultAttr)
+            Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
           }
         case otherType =>
           throw new SparkException(s"VectorAssembler does not support the $otherType type")

http://git-wip-us.apache.org/repos/asf/spark/blob/5e492e9d/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index dc20a5e..16e565d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(attrs === expectedAttrs)
   }
 
+  test("vector attribute generation") {
+    val formula = new RFormula().setFormula("id ~ vec")
+    val original = sqlContext.createDataFrame(
+      Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
+    ).toDF("id", "vec")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("vec_0"), Some(1)),
+        new NumericAttribute(Some("vec_1"), Some(2))))
+    assert(attrs === expectedAttrs)
+  }
+
+  test("vector attribute generation with unnamed input attrs") {
+    val formula = new RFormula().setFormula("id ~ vec2")
+    val base = sqlContext.createDataFrame(
+      Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
+    ).toDF("id", "vec")
+    val metadata = new AttributeGroup(
+      "vec2",
+      Array[Attribute](
+        NumericAttribute.defaultAttr,
+        NumericAttribute.defaultAttr)).toMetadata
+    val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("vec2_0"), Some(1)),
+        new NumericAttribute(Some("vec2_1"), Some(2))))
+    assert(attrs === expectedAttrs)
+  }
+
   test("numeric interaction") {
     val formula = new RFormula().setFormula("a ~ b:c:d")
     val original = sqlContext.createDataFrame(

http://git-wip-us.apache.org/repos/asf/spark/blob/5e492e9d/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index f7de7c1..dce994f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -111,8 +111,8 @@ class VectorAssemblerSuite
     assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
     val userSalaryOut = features.getAttr(4)
     assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
-    assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
-    assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
+    assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0"))
+    assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1"))
   }
 
   test("read/write") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org