You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/15 18:48:40 UTC
spark git commit: [SPARK-8840] [SPARKR] Add float coercion on SparkR
Repository: spark
Updated Branches:
refs/heads/master 20bb10f86 -> 6f6902597
[SPARK-8840] [SPARKR] Add float coercion on SparkR
JIRA: https://issues.apache.org/jira/browse/SPARK-8840
Currently, SparkR's type coercion rules do not include the float type. This PR adds float support to them.
Author: Liang-Chi Hsieh <vi...@appier.com>
Closes #7280 from viirya/add_r_float_coercion and squashes the following commits:
c86dc0e [Liang-Chi Hsieh] For comments.
dbf0c1b [Liang-Chi Hsieh] Implicitly convert Double to Float based on provided schema.
733015a [Liang-Chi Hsieh] Add test case for DataFrame with float type.
30c2a40 [Liang-Chi Hsieh] Update test case.
52b5294 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_r_float_coercion
6f9159d [Liang-Chi Hsieh] Add another test case.
8db3244 [Liang-Chi Hsieh] schema also needs to support float. add test case.
0dcc992 [Liang-Chi Hsieh] Add float coercion on SparkR.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6f690259
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6f690259
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6f690259
Branch: refs/heads/master
Commit: 6f6902597d5d687049c103bc0cf6da30919b92d8
Parents: 20bb10f
Author: Liang-Chi Hsieh <vi...@appier.com>
Authored: Wed Jul 15 09:48:33 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Jul 15 09:48:33 2015 -0700
----------------------------------------------------------------------
R/pkg/R/deserialize.R | 1 +
R/pkg/R/schema.R | 1 +
R/pkg/inst/tests/test_sparkSQL.R | 26 ++++++++++++++++++++
.../scala/org/apache/spark/api/r/SerDe.scala | 4 +++
.../org/apache/spark/sql/api/r/SQLUtils.scala | 15 ++++++++---
5 files changed, 44 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6f690259/R/pkg/R/deserialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index d961bbc..7d1f6b0 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -23,6 +23,7 @@
# Int -> integer
# String -> character
# Boolean -> logical
+# Float -> double
# Double -> double
# Long -> double
# Array[Byte] -> raw
http://git-wip-us.apache.org/repos/asf/spark/blob/6f690259/R/pkg/R/schema.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 15e2bdb..06df430 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -123,6 +123,7 @@ structField.character <- function(x, type, nullable = TRUE) {
}
options <- c("byte",
"integer",
+ "float",
"double",
"numeric",
"character",
http://git-wip-us.apache.org/repos/asf/spark/blob/6f690259/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index b0ea388..76f74f8 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -108,6 +108,32 @@ test_that("create DataFrame from RDD", {
expect_equal(count(df), 10)
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+ df <- jsonFile(sqlContext, jsonPathNa)
+ hiveCtx <- tryCatch({
+ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+ }, error = function(err) {
+ skip("Hive is not build with SparkSQL, skipped")
+ })
+ sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)")
+ insertInto(df, "people")
+ expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16))
+ expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5))
+
+ schema <- structType(structField("name", "string"), structField("age", "integer"),
+ structField("height", "float"))
+ df2 <- createDataFrame(sqlContext, df.toRDD, schema)
+ expect_equal(columns(df2), c("name", "age", "height"))
+ expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float")))
+ expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5))
+
+ localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7))
+ df <- createDataFrame(sqlContext, localDF, schema)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 3)
+ expect_equal(columns(df), c("name", "age", "height"))
+ expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
+ expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10))
})
test_that("convert NAs to null type in DataFrames", {
http://git-wip-us.apache.org/repos/asf/spark/blob/6f690259/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 56adc85..d5b4260 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -179,6 +179,7 @@ private[spark] object SerDe {
// Int -> integer
// String -> character
// Boolean -> logical
+ // Float -> double
// Double -> double
// Long -> double
// Array[Byte] -> raw
@@ -215,6 +216,9 @@ private[spark] object SerDe {
case "long" | "java.lang.Long" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Long].toDouble)
+ case "float" | "java.lang.Float" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Float].toDouble)
case "double" | "java.lang.Double" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Double])
http://git-wip-us.apache.org/repos/asf/spark/blob/6f690259/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 43b62f0..92861ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -47,6 +47,7 @@ private[r] object SQLUtils {
dataType match {
case "byte" => org.apache.spark.sql.types.ByteType
case "integer" => org.apache.spark.sql.types.IntegerType
+ case "float" => org.apache.spark.sql.types.FloatType
case "double" => org.apache.spark.sql.types.DoubleType
case "numeric" => org.apache.spark.sql.types.DoubleType
case "character" => org.apache.spark.sql.types.StringType
@@ -68,7 +69,7 @@ private[r] object SQLUtils {
def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
val num = schema.fields.size
- val rowRDD = rdd.map(bytesToRow)
+ val rowRDD = rdd.map(bytesToRow(_, schema))
sqlContext.createDataFrame(rowRDD, schema)
}
@@ -76,12 +77,20 @@ private[r] object SQLUtils {
df.map(r => rowToRBytes(r))
}
- private[this] def bytesToRow(bytes: Array[Byte]): Row = {
+ private[this] def doConversion(data: Object, dataType: DataType): Object = {
+ data match {
+ case d: java.lang.Double if dataType == FloatType =>
+ new java.lang.Float(d)
+ case _ => data
+ }
+ }
+
+ private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
val bis = new ByteArrayInputStream(bytes)
val dis = new DataInputStream(bis)
val num = SerDe.readInt(dis)
Row.fromSeq((0 until num).map { i =>
- SerDe.readObject(dis)
+ doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
}.toSeq)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org