You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/04/21 23:47:16 UTC

spark git commit: [SPARK-6994] Allow to fetch field values by name in sql.Row

Repository: spark
Updated Branches:
  refs/heads/master 04bf34e34 -> 2e8c6ca47


[SPARK-6994] Allow to fetch field values by name in sql.Row

Until now, Spark's Scala API offered no way to access the fields of a `DataFrame`/`sql.Row` by name, only by index.

This change adds name-based field access to `Row`.

Author: vidmantas zemleris <vi...@vinted.com>

Closes #5573 from vidma/features/row-with-named-fields and squashes the following commits:

6145ae3 [vidmantas zemleris] [SPARK-6994][SQL] Allow to fetch field values by name on Row
9564ebb [vidmantas zemleris] [SPARK-6994][SQL] Add fieldIndex to schema (StructType)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2e8c6ca4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2e8c6ca4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2e8c6ca4

Branch: refs/heads/master
Commit: 2e8c6ca47df14681c1110f0736234ce76a3eca9b
Parents: 04bf34e
Author: vidmantas zemleris <vi...@vinted.com>
Authored: Tue Apr 21 14:47:09 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Apr 21 14:47:09 2015 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/Row.scala   | 32 +++++++++
 .../spark/sql/catalyst/expressions/rows.scala   |  2 +
 .../org/apache/spark/sql/types/dataTypes.scala  |  9 +++
 .../scala/org/apache/spark/sql/RowTest.scala    | 71 ++++++++++++++++++++
 .../apache/spark/sql/types/DataTypeSuite.scala  | 13 ++++
 .../scala/org/apache/spark/sql/RowSuite.scala   | 10 +++
 6 files changed, 137 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2e8c6ca4/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index ac8a782..4190b7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -306,6 +306,38 @@ trait Row extends Serializable {
    */
   def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
 
+  /**
+   * Returns the value of the field with the given name, cast to `T`.
+   *
+   * @throws UnsupportedOperationException when schema is not defined.
+   * @throws IllegalArgumentException when fieldName does not exist.
+   * @throws ClassCastException when data type does not match.
+   */
+  def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName))
+
+  /**
+   * Returns the index of a given field name. Always throws here; schema-aware rows override it.
+   *
+   * @throws UnsupportedOperationException when schema is not defined.
+   * @throws IllegalArgumentException when the field name does not exist.
+   */
+  def fieldIndex(name: String): Int = {
+    throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.")
+  }
+
+  /**
+   * Returns a Map(name -> value) for the requested fieldNames.
+   *
+   * @throws UnsupportedOperationException when schema is not defined.
+   * @throws IllegalArgumentException when any of the fieldNames does not exist.
+   * @throws ClassCastException when data type does not match.
+   */
+  def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = {
+    fieldNames.map { name =>
+      name -> getAs[T](name)
+    }.toMap
+  }
+
   override def toString(): String = s"[${this.mkString(",")}]"
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/2e8c6ca4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index b6ec7d3..9813734 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
 
   /** No-arg constructor for serialization. */
   protected def this() = this(null, null)
+
+  override def fieldIndex(name: String): Int = schema.fieldIndex(name) // delegate to schema
 }
 
 class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {

http://git-wip-us.apache.org/repos/asf/spark/blob/2e8c6ca4/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index a108413..7cd7bd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
 
   private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
   private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
+  private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
 
   /**
    * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
@@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
     StructType(fields.filter(f => names.contains(f.name)))
   }
 
+  /**
+   * Returns the index of the given field; throws IllegalArgumentException if it does not exist.
+   */
+  def fieldIndex(name: String): Int = {
+    nameToIndex.getOrElse(name,
+      throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
+  }
+
   protected[sql] def toAttributes: Seq[AttributeReference] =
     map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2e8c6ca4/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
new file mode 100644
index 0000000..bbb9739
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
+import org.apache.spark.sql.types._
+import org.scalatest.{Matchers, FunSpec}
+
+class RowTest extends FunSpec with Matchers {
+
+  val schema = StructType(
+    StructField("col1", StringType) ::
+    StructField("col2", StringType) ::
+    StructField("col3", IntegerType) :: Nil)
+  val values = Array("value1", "value2", 1) // mixed element types, so inferred as Array[Any]
+
+  val sampleRow: Row = new GenericRowWithSchema(values, schema)
+  val noSchemaRow: Row = new GenericRow(values) // same values, but no schema attached
+
+  describe("Row (without schema)") {
+    it("throws an exception when accessing by fieldName") {
+      intercept[UnsupportedOperationException] {
+        noSchemaRow.fieldIndex("col1")
+      }
+      intercept[UnsupportedOperationException] {
+        noSchemaRow.getAs("col1") // fieldIndex throws before any value is read
+      }
+    }
+  }
+
+  describe("Row (with schema)") {
+    it("fieldIndex(name) returns field index") {
+      sampleRow.fieldIndex("col1") shouldBe 0
+      sampleRow.fieldIndex("col3") shouldBe 2
+    }
+
+    it("getAs[T] retrieves a value by fieldname") {
+      sampleRow.getAs[String]("col1") shouldBe "value1"
+      sampleRow.getAs[Int]("col3") shouldBe 1
+    }
+
+    it("Accessing non existent field throws an exception") {
+      intercept[IllegalArgumentException] {
+        sampleRow.getAs[String]("non_existent")
+      }
+    }
+
+    it("getValuesMap() retrieves values of multiple fields as a Map(field -> value)") {
+      val expected = Map(
+        "col1" -> "value1",
+        "col2" -> "value2"
+      )
+      sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2e8c6ca4/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index a1341ea..d797510 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite {
     }
   }
 
+  test("extract field index from a StructType") {
+    val struct = StructType(
+      StructField("a", LongType) ::
+      StructField("b", FloatType) :: Nil)
+
+    assert(struct.fieldIndex("a") === 0) // indices follow field declaration order
+    assert(struct.fieldIndex("b") === 1)
+
+    intercept[IllegalArgumentException] {
+      struct.fieldIndex("non_existent")
+    }
+  }
+
   def checkDataTypeJsonRepr(dataType: DataType): Unit = {
     test(s"JSON - $dataType") {
       assert(DataType.fromJson(dataType.json) === dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/2e8c6ca4/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index bf6cf13..fb3ba4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -62,4 +62,14 @@ class RowSuite extends FunSuite {
     val de = instance.deserialize(ser).asInstanceOf[Row]
     assert(de === row)
   }
+
+  test("get values by field name on Row created via .toDF") {
+    val row = Seq((1, Seq(1))).toDF("a", "b").first()
+    assert(row.getAs[Int]("a") === 1)
+    assert(row.getAs[Seq[Int]]("b") === Seq(1))
+
+    intercept[IllegalArgumentException]{
+      row.getAs[Int]("c") // only columns "a" and "b" were named in toDF
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org