Posted to commits@spark.apache.org by fe...@apache.org on 2016/09/02 08:47:25 UTC

spark git commit: [SPARK-16883][SPARKR] SQL decimal type is not properly cast to number when collecting SparkDataFrame

Repository: spark
Updated Branches:
  refs/heads/master 2ab8dbdda -> 0f30cdedb


[SPARK-16883][SPARKR] SQL decimal type is not properly cast to number when collecting SparkDataFrame

## What changes were proposed in this pull request?

Casting a column to `decimal` and collecting the result returns a list column instead of a numeric vector:

registerTempTable(createDataFrame(iris), "iris")
str(collect(sql("select cast('1' as double) as x, cast('2' as decimal) as y  from iris limit 5")))

'data.frame':	5 obs. of  2 variables:
 $ x: num  1 1 1 1 1
 $ y:List of 5
  ..$ : num 2
  ..$ : num 2
  ..$ : num 2
  ..$ : num 2
  ..$ : num 2

The problem is that Spark returns the column type `decimal(10, 0)` instead of `decimal`, so `decimal(10, 0)` is not recognized as a primitive type and the column is not converted correctly. It should be handled as `double`.
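For reference, the backend-reported type can be inspected with `dtypes`; a minimal sketch, assuming the same `iris` temp table as above (printed output is illustrative):

> df <- sql("select cast('2' as decimal) as y from iris limit 5")
> dtypes(df)
[[1]]
[1] "y"             "decimal(10,0)"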

As discussed in the JIRA thread, there are two potential fixes:
1) Scala-side fix: add a new case when writing the object back. However, `spark.sql.types._` cannot be used in Spark Core due to dependency issues, so I have not found a way to do the type case match there.

2) SparkR-side fix: add a helper function that checks for special types like `"decimal(10, 0)"` and replaces them with `double`, which is a PRIMITIVE type (see the `types.R` change below). This helper is generic, so handling for other special types can be added in the future.

I opened this PR to discuss the pros and cons of both approaches. If we want the Scala-side fix, we need to find a way to match the cases of `DecimalType` and `StructType` in Spark Core.

## How was this patch tested?


Manual test:
> str(collect(sql("select cast('1' as double) as x, cast('2' as decimal) as y  from iris limit 5")))
'data.frame':	5 obs. of  2 variables:
 $ x: num  1 1 1 1 1
 $ y: num  2 2 2 2 2
R unit tests were also added (see `test_sparkSQL.R` below).
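With the fix, `coltypes` likewise maps the decimal column to R's "numeric", matching the new unit test below (illustrative sketch):

> coltypes(sql("select cast('2' as decimal) as y from iris limit 1"))
[1] "numeric"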

Author: wm624@hotmail.com <wm...@hotmail.com>

Closes #14613 from wangmiao1981/type.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f30cded
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f30cded
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f30cded

Branch: refs/heads/master
Commit: 0f30cdedbdb0d38e8c479efab6bb1c6c376206ff
Parents: 2ab8dbd
Author: wm624@hotmail.com <wm...@hotmail.com>
Authored: Fri Sep 2 01:47:17 2016 -0700
Committer: Felix Cheung <fe...@apache.org>
Committed: Fri Sep 2 01:47:17 2016 -0700

----------------------------------------------------------------------
 R/pkg/R/DataFrame.R                       | 13 ++++++++++++-
 R/pkg/R/types.R                           | 16 ++++++++++++++++
 R/pkg/inst/tests/testthat/test_sparkSQL.R | 22 ++++++++++++++++++++++
 3 files changed, 50 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f30cded/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index e12b58e..a924502 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -397,7 +397,11 @@ setMethod("coltypes",
                 }
 
                 if (is.null(type)) {
-                  stop(paste("Unsupported data type: ", x))
+                  specialtype <- specialtypeshandle(x)
+                  if (is.null(specialtype)) {
+                    stop(paste("Unsupported data type: ", x))
+                  }
+                  type <- PRIMITIVE_TYPES[[specialtype]]
                 }
               }
               type
@@ -1063,6 +1067,13 @@ setMethod("collect",
                   df[[colIndex]] <- col
                 } else {
                   colType <- dtypes[[colIndex]][[2]]
+                  if (is.null(PRIMITIVE_TYPES[[colType]])) {
+                    specialtype <- specialtypeshandle(colType)
+                    if (!is.null(specialtype)) {
+                      colType <- specialtype
+                    }
+                  }
+
                   # Note that "binary" columns behave like complex types.
                   if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") {
                     vec <- do.call(c, col)

http://git-wip-us.apache.org/repos/asf/spark/blob/0f30cded/R/pkg/R/types.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R
index ad048b1..abca703 100644
--- a/R/pkg/R/types.R
+++ b/R/pkg/R/types.R
@@ -67,3 +67,19 @@ rToSQLTypes <- as.environment(list(
   "double" = "double",
   "character" = "string",
   "logical" = "boolean"))
+
+# Helper function for converting decimal types. When the backend returns a column
+# type in the format decimal(p, s) (e.g., decimal(10, 0)), this function converts
+# the column type to double. It handles backend-returned types that are not keys
+# of PRIMITIVE_TYPES but should be treated as primitive types.
+# @param type A type string returned from the JVM backend.
+# @return A type that is a key of PRIMITIVE_TYPES, or NULL if there is no match.
+specialtypeshandle <- function(type) {
+  returntype <- NULL
+  m <- regexec("^decimal(.+)$", type)
+  matchedStrings <- regmatches(type, m)
+  if (length(matchedStrings[[1]]) >= 2) {
+    returntype <- "double"
+  }
+  returntype
+}
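For illustration, the helper maps types as follows (expected output shown; plain `decimal` is already a key of PRIMITIVE_TYPES, so the pattern requires a precision/scale suffix and leaves it untouched):

> specialtypeshandle("decimal(10, 0)")
[1] "double"
> specialtypeshandle("array<int>")
NULL
> specialtypeshandle("decimal")
NULL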

http://git-wip-us.apache.org/repos/asf/spark/blob/0f30cded/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 8ff56eb..683a15c 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -526,6 +526,17 @@ test_that(
   expect_is(newdf, "SparkDataFrame")
   expect_equal(count(newdf), 1)
   dropTempView("table1")
+
+  createOrReplaceTempView(df, "dfView")
+  sqlCast <- collect(sql("select cast('2' as decimal) as x from dfView limit 1"))
+  out <- capture.output(sqlCast)
+  expect_true(is.data.frame(sqlCast))
+  expect_equal(names(sqlCast)[1], "x")
+  expect_equal(nrow(sqlCast), 1)
+  expect_equal(ncol(sqlCast), 1)
+  expect_equal(out[1], "  x")
+  expect_equal(out[2], "1 2")
+  dropTempView("dfView")
 })
 
 test_that("test cache, uncache and clearCache", {
@@ -2089,6 +2100,9 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", {
   # Test primitive types
   DF <- createDataFrame(data, schema)
   expect_equal(coltypes(DF), c("integer", "logical", "POSIXct"))
+  createOrReplaceTempView(DF, "DFView")
+  sqlCast <- sql("select cast('2' as decimal) as x from DFView limit 1")
+  expect_equal(coltypes(sqlCast), "numeric")
 
   # Test complex types
   x <- createDataFrame(list(list(as.environment(
@@ -2132,6 +2146,14 @@ test_that("Method str()", {
                               "setosa\" \"setosa\" \"setosa\" \"setosa\""))
   expect_equal(out[7], " $ col         : logi TRUE TRUE TRUE TRUE TRUE TRUE")
 
+  createOrReplaceTempView(irisDF2, "irisView")
+
+  sqlCast <- sql("select cast('2' as decimal) as x from irisView limit 1")
+  castStr <- capture.output(str(sqlCast))
+  expect_equal(length(castStr), 2)
+  expect_equal(castStr[1], "'SparkDataFrame': 1 variables:")
+  expect_equal(castStr[2], " $ x: num 2")
+
   # A random dataset with many columns. This test is to check str limits
   # the number of columns. Therefore, it will suffice to check for the
   # number of returned rows

