Posted to commits@spark.apache.org by sh...@apache.org on 2015/10/13 19:02:26 UTC

spark git commit: [SPARK-10051] [SPARKR] Support collecting data of StructType in DataFrame

Repository: spark
Updated Branches:
  refs/heads/master d0cc79ccd -> 5e3868ba1


[SPARK-10051] [SPARKR] Support collecting data of StructType in DataFrame

Two points in this PR:

1.    Originally, a named R list was assumed to be a struct in SerDe. But this is problematic because some R functions implicitly generate named lists that are not intended to be structs when transferred by SerDe. So SerDe clients now have to explicitly mark a named list as a struct by changing its class from "list" to "struct" (see the sketch after these points).

2.    SerDe lives in the Spark Core module, while data of StructType is represented as GenericRow, which is defined in the Spark SQL module. SerDe can't import GenericRow because, in the Maven build, the Spark SQL module depends on the Spark Core module rather than the other way around. So this PR adds a registration hook in SerDe that lets SQLUtils in the Spark SQL module register its functions for serialization and deserialization of StructType.
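
As a rough sketch of how the two points fit together on the R side (hedged;
based on the tests added in this commit, and assuming an initialized
sqlContext):

    # A plain named list stays a list. Only an explicit listToStruct()
    # call tags it with class "struct", so SerDe serializes it as a struct.
    s <- listToStruct(list(a = "aa", b = 3L))
    class(s)    # "struct"

    # The struct round-trips through the JVM: it becomes a StructType
    # column in the DataFrame and collects back as a "struct" list.
    df <- createDataFrame(sqlContext, list(list(s)), c("d"))
    dtypes(df)              # list(c("d", "struct<a:string,b:int>"))
    collect(df)$d[[1]]$a    # "aa"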

Author: Sun Rui <ru...@intel.com>

Closes #8794 from sun-rui/SPARK-10051.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5e3868ba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5e3868ba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5e3868ba

Branch: refs/heads/master
Commit: 5e3868ba139f5f0b3a33361c6b884594a3ab6421
Parents: d0cc79c
Author: Sun Rui <ru...@intel.com>
Authored: Tue Oct 13 10:02:21 2015 -0700
Committer: Shivaram Venkataraman <sh...@cs.berkeley.edu>
Committed: Tue Oct 13 10:02:21 2015 -0700

----------------------------------------------------------------------
 R/pkg/R/SQLContext.R                            | 22 +++---
 R/pkg/R/deserialize.R                           | 10 +++
 R/pkg/R/schema.R                                | 28 +++++++-
 R/pkg/R/serialize.R                             | 43 ++++++++----
 R/pkg/R/sparkR.R                                |  4 +-
 R/pkg/R/utils.R                                 | 17 +++++
 R/pkg/inst/tests/test_sparkSQL.R                | 51 ++++++++------
 .../scala/org/apache/spark/api/r/SerDe.scala    | 71 +++++++++++++++-----
 .../org/apache/spark/sql/api/r/SQLUtils.scala   | 47 +++++++++++--
 9 files changed, 224 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/R/pkg/R/SQLContext.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 1c58fd9..66c7e30 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -32,6 +32,7 @@ infer_type <- function(x) {
                  numeric = "double",
                  raw = "binary",
                  list = "array",
+                 struct = "struct",
                  environment = "map",
                  Date = "date",
                  POSIXlt = "timestamp",
@@ -44,17 +45,18 @@ infer_type <- function(x) {
     paste0("map<string,", infer_type(get(key, x)), ">")
   } else if (type == "array") {
     stopifnot(length(x) > 0)
+
+    paste0("array<", infer_type(x[[1]]), ">")
+  } else if (type == "struct") {
+    stopifnot(length(x) > 0)
     names <- names(x)
-    if (is.null(names)) {
-      paste0("array<", infer_type(x[[1]]), ">")
-    } else {
-      # StructType
-      types <- lapply(x, infer_type)
-      fields <- lapply(1:length(x), function(i) {
-        structField(names[[i]], types[[i]], TRUE)
-      })
-      do.call(structType, fields)
-    }
+    stopifnot(!is.null(names))
+
+    type <- lapply(seq_along(x), function(i) {
+      paste0(names[[i]], ":", infer_type(x[[i]]), ",")
+    })
+    type <- Reduce(paste0, type)
+    type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">")
   } else if (length(x) > 1) {
     paste0("array<", infer_type(x[[1]]), ">")
   } else {

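A quick check of the new inference behavior, lifted from the updated tests in
test_sparkSQL.R: a plain list still infers as an array, and only a
struct-tagged list infers as a struct.

    infer_type(list(1L, 2L))                         # "array<integer>"
    infer_type(listToStruct(list(a = 1L, b = "2")))  # "struct<a:integer,b:string>"
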
http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/R/pkg/R/deserialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index ce88d0b..f7e56e4 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -51,6 +51,7 @@ readTypedObject <- function(con, type) {
     "a" = readArray(con),
     "l" = readList(con),
     "e" = readEnv(con),
+    "s" = readStruct(con),
     "n" = NULL,
     "j" = getJobj(readString(con)),
     stop(paste("Unsupported type for deserialization", type)))
@@ -135,6 +136,15 @@ readEnv <- function(con) {
   env
 }
 
+# Read a field of StructType from a DataFrame
+# into a named list in R whose class is "struct"
+readStruct <- function(con) {
+  names <- readObject(con)
+  fields <- readObject(con)
+  names(fields) <- names
+  listToStruct(fields)
+}
+
 readRaw <- function(con) {
   dataLen <- readInt(con)
   readBin(con, raw(), as.integer(dataLen), endian = "big")

http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/R/pkg/R/schema.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 8df1563..6f0e9a9 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -136,7 +136,7 @@ checkType <- function(type) {
     switch (firstChar,
             a = {
               # Array type
-              m <- regexec("^array<(.*)>$", type)
+              m <- regexec("^array<(.+)>$", type)
               matchedStrings <- regmatches(type, m)
               if (length(matchedStrings[[1]]) >= 2) {
                 elemType <- matchedStrings[[1]][2]
@@ -146,7 +146,7 @@ checkType <- function(type) {
             },
             m = {
               # Map type
-              m <- regexec("^map<(.*),(.*)>$", type)
+              m <- regexec("^map<(.+),(.+)>$", type)
               matchedStrings <- regmatches(type, m)
               if (length(matchedStrings[[1]]) >= 3) {
                 keyType <- matchedStrings[[1]][2]
@@ -157,6 +157,30 @@ checkType <- function(type) {
                 checkType(valueType)
                 return()
               }
+            },
+            s = {
+              # Struct type
+              m <- regexec("^struct<(.+)>$", type)
+              matchedStrings <- regmatches(type, m)
+              if (length(matchedStrings[[1]]) >= 2) {
+                fieldsString <- matchedStrings[[1]][2]
+                # strsplit does not return the final empty string, so check if
+                # the final char is ","
+                if (substr(fieldsString, nchar(fieldsString), nchar(fieldsString)) != ",") {
+                  fields <- strsplit(fieldsString, ",")[[1]]
+                  for (field in fields) {
+                    m <- regexec("^(.+):(.+)$", field)
+                    matchedStrings <- regmatches(field, m)
+                    if (length(matchedStrings[[1]]) >= 3) {
+                      fieldType <- matchedStrings[[1]][3]
+                      checkType(fieldType)
+                    } else {
+                      break
+                    }
+                  }
+                  return()
+                }
+              }
             })
   }
 

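checkType() now accepts struct type strings of the form struct<name:type,...>
and validates each field type recursively. A hedged sketch via structField(),
which calls checkType() internally:

    # accepted: each field type is checked recursively
    structField("d", "struct<a:string,b:integer>", TRUE)
    # stops with an error: the trailing comma would leave an empty field
    structField("d", "struct<a:string,>", TRUE)
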
http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/R/pkg/R/serialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 91e6b3e..17082b4 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -32,6 +32,21 @@
 # environment -> Map[String, T], where T is a native type
 # jobj -> Object, where jobj is an object created in the backend
 
+getSerdeType <- function(object) {
+  type <- class(object)[[1]]
+  if (type != "list") {
+    type
+  } else {
+    # Check if all elements are of the same type
+    elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) }))
+    if (length(elemType) <= 1) {
+      "array"
+    } else {
+      "list"
+    }
+  }
+}
+
 writeObject <- function(con, object, writeType = TRUE) {
   # NOTE: In R vectors have same type as objects. So we don't support
   # passing in vectors as arrays and instead require arrays to be passed
@@ -45,10 +60,12 @@ writeObject <- function(con, object, writeType = TRUE) {
       type <- "NULL"
     }
   }
+
+  serdeType <- getSerdeType(object)
   if (writeType) {
-    writeType(con, type)
+    writeType(con, serdeType)
   }
-  switch(type,
+  switch(serdeType,
          NULL = writeVoid(con),
          integer = writeInt(con, object),
          character = writeString(con, object),
@@ -56,7 +73,9 @@ writeObject <- function(con, object, writeType = TRUE) {
          double = writeDouble(con, object),
          numeric = writeDouble(con, object),
          raw = writeRaw(con, object),
+         array = writeArray(con, object),
          list = writeList(con, object),
+         struct = writeList(con, object),
          jobj = writeJobj(con, object),
          environment = writeEnv(con, object),
          Date = writeDate(con, object),
@@ -110,7 +129,7 @@ writeRowSerialize <- function(outputCon, rows) {
 serializeRow <- function(row) {
   rawObj <- rawConnection(raw(0), "wb")
   on.exit(close(rawObj))
-  writeGenericList(rawObj, row)
+  writeList(rawObj, row)
   rawConnectionValue(rawObj)
 }
 
@@ -128,7 +147,9 @@ writeType <- function(con, class) {
                  double = "d",
                  numeric = "d",
                  raw = "r",
+                 array = "a",
                  list = "l",
+                 struct = "s",
                  jobj = "j",
                  environment = "e",
                  Date = "D",
@@ -139,15 +160,13 @@ writeType <- function(con, class) {
 }
 
 # Used to pass arrays where all the elements are of the same type
-writeList <- function(con, arr) {
-  # All elements should be of same type
-  elemType <- unique(sapply(arr, function(elem) { class(elem) }))
-  stopifnot(length(elemType) <= 1)
-
+writeArray <- function(con, arr) {
   # TODO: Empty lists are given type "character" right now.
   # This may not work if the Java side expects array of any other type.
-  if (length(elemType) == 0) {
+  if (length(arr) == 0) {
     elemType <- class("somestring")
+  } else {
+    elemType <- getSerdeType(arr[[1]])
   }
 
   writeType(con, elemType)
@@ -161,7 +180,7 @@ writeList <- function(con, arr) {
 }
 
 # Used to pass arrays where the elements can be of different types
-writeGenericList <- function(con, list) {
+writeList <- function(con, list) {
   writeInt(con, length(list))
   for (elem in list) {
     writeObject(con, elem)
@@ -174,9 +193,9 @@ writeEnv <- function(con, env) {
 
   writeInt(con, len)
   if (len > 0) {
-    writeList(con, as.list(ls(env)))
+    writeArray(con, as.list(ls(env)))
     vals <- lapply(ls(env), function(x) { env[[x]] })
-    writeGenericList(con, as.list(vals))
+    writeList(con, as.list(vals))
   }
 }
 
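The new internal helper getSerdeType() picks the wire type for a list:
homogeneous lists are sent as arrays, heterogeneous ones as generic lists, and
a "struct" class tag takes precedence over both. A small sketch of the
dispatch:

    getSerdeType(list(1L, 2L))                # "array"  -> writeArray()
    getSerdeType(list(1L, "a"))               # "list"   -> writeList()
    getSerdeType(listToStruct(list(a = 1L)))  # "struct" -> writeList(), tagged "s"
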

http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/R/pkg/R/sparkR.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 3c57a44..cc47110 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -178,7 +178,7 @@ sparkR.init <- function(
   }
 
   nonEmptyJars <- Filter(function(x) { x != "" }, jars)
-  localJarPaths <- sapply(nonEmptyJars,
+  localJarPaths <- lapply(nonEmptyJars,
                           function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) })
 
   # Set the start time to identify jobjs
@@ -193,7 +193,7 @@ sparkR.init <- function(
       master,
       appName,
       as.character(sparkHome),
-      as.list(localJarPaths),
+      localJarPaths,
       sparkEnvirMap,
       sparkExecutorEnvMap),
     envir = .sparkREnv

http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/R/pkg/R/utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 69a2bc7..94f16c7 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -588,3 +588,20 @@ mergePartitions <- function(rdd, zip) {
 
   PipelinedRDD(rdd, partitionFunc)
 }
+
+# Convert a named list to a struct so that
+# SerDe won't confuse a normal named list with a struct
+listToStruct <- function(list) {
+  stopifnot(class(list) == "list")
+  stopifnot(!is.null(names(list)))
+  class(list) <- "struct"
+  list
+}
+
+# Convert a struct to a named list
+structToList <- function(struct) {
+  stopifnot(class(struct) == "struct")
+
+  class(struct) <- "list"
+  struct
+}

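These two helpers only toggle the S3 class attribute, so converting between
the representations is cheap and lossless:

    s <- listToStruct(list(a = 1L, b = "2"))  # class(s) is "struct"
    l <- structToList(s)                      # back to a plain named list
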
http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 3a04edb..af6efa4 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -66,10 +66,7 @@ test_that("infer types and check types", {
   expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
   expect_equal(infer_type(c(1L, 2L)), "array<integer>")
   expect_equal(infer_type(list(1L, 2L)), "array<integer>")
-  testStruct <- infer_type(list(a = 1L, b = "2"))
-  expect_equal(class(testStruct), "structType")
-  checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
-  checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE)
+  expect_equal(infer_type(listToStruct(list(a = 1L, b = "2"))), "struct<a:integer,b:string>")
   e <- new.env()
   assign("a", 1L, envir = e)
   expect_equal(infer_type(e), "map<string,integer>")
@@ -242,38 +239,36 @@ test_that("create DataFrame with different data types", {
   expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
 })
 
-test_that("create DataFrame with nested array and map", {
-#  e <- new.env()
-#  assign("n", 3L, envir = e)
-#  l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
-#  df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d"))
-#  expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>"),
-#                                c("c", "map<string,int>"), c("d", "struct<a:string,b:int>")))
-#  expect_equal(count(df), 1)
-#  ldf <- collect(df)
-#  expect_equal(ldf[1,], l[[1]])
-
-  #  ArrayType and MapType
+test_that("create DataFrame with complex types", {
   e <- new.env()
   assign("n", 3L, envir = e)
 
-  l <- list(as.list(1:10), list("a", "b"), e)
-  df <- createDataFrame(sqlContext, list(l), c("a", "b", "c"))
+  s <- listToStruct(list(a = "aa", b = 3L))
+
+  l <- list(as.list(1:10), list("a", "b"), e, s)
+  df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d"))
   expect_equal(dtypes(df), list(c("a", "array<int>"),
                                 c("b", "array<string>"),
-                                c("c", "map<string,int>")))
+                                c("c", "map<string,int>"),
+                                c("d", "struct<a:string,b:int>")))
   expect_equal(count(df), 1)
   ldf <- collect(df)
-  expect_equal(names(ldf), c("a", "b", "c"))
+  expect_equal(names(ldf), c("a", "b", "c", "d"))
   expect_equal(ldf[1, 1][[1]], l[[1]])
   expect_equal(ldf[1, 2][[1]], l[[2]])
+
   e <- ldf$c[[1]]
   expect_equal(class(e), "environment")
   expect_equal(ls(e), "n")
   expect_equal(e$n, 3L)
+
+  s <- ldf$d[[1]]
+  expect_equal(class(s), "struct")
+  expect_equal(s$a, "aa")
+  expect_equal(s$b, 3L)
 })
 
-# For test map type in DataFrame
+# For testing map type and struct type in DataFrame
 mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}",
                       "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}",
                       "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}")
@@ -308,7 +303,19 @@ test_that("Collect DataFrame with complex types", {
   expect_equal(bob$age, 16)
   expect_equal(bob$height, 176.5)
 
-  # TODO: tests for StructType after it is supported
+  # StructType
+  df <- jsonFile(sqlContext, mapTypeJsonPath)
+  expect_equal(dtypes(df), list(c("info", "struct<age:bigint,height:double>"),
+                                c("name", "string")))
+  ldf <- collect(df)
+  expect_equal(nrow(ldf), 3)
+  expect_equal(ncol(ldf), 2)
+  expect_equal(names(ldf), c("info", "name"))
+  expect_equal(ldf$name, c("Bob", "Alice", "David"))
+  bob <- ldf$info[[1]]
+  expect_equal(class(bob), "struct")
+  expect_equal(bob$age, 16)
+  expect_equal(bob$height, 176.5)
 })
 
 test_that("jsonFile() on a local file returns a DataFrame", {

http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 0c78613..da126ba 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -27,6 +27,14 @@ import scala.collection.mutable.WrappedArray
  * Utility functions to serialize, deserialize objects to / from R
  */
 private[spark] object SerDe {
+  type ReadObject = (DataInputStream, Char) => Object
+  type WriteObject = (DataOutputStream, Object) => Boolean
+
+  var sqlSerDe: (ReadObject, WriteObject) = _
+
+  def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = {
+    this.sqlSerDe = sqlSerDe
+  }
 
   // Type mapping from R to Java
   //
@@ -63,11 +71,22 @@ private[spark] object SerDe {
       case 'c' => readString(dis)
       case 'e' => readMap(dis)
       case 'r' => readBytes(dis)
+      case 'a' => readArray(dis)
       case 'l' => readList(dis)
       case 'D' => readDate(dis)
       case 't' => readTime(dis)
       case 'j' => JVMObjectTracker.getObject(readString(dis))
-      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+      case _ =>
+        if (sqlSerDe == null || sqlSerDe._1 == null) {
+          throw new IllegalArgumentException(s"Invalid type $dataType")
+        } else {
+          val obj = (sqlSerDe._1)(dis, dataType)
+          if (obj == null) {
+            throw new IllegalArgumentException(s"Invalid type $dataType")
+          } else {
+            obj
+          }
+        }
     }
   }
 
@@ -141,7 +160,8 @@ private[spark] object SerDe {
     (0 until len).map(_ => readString(in)).toArray
   }
 
-  def readList(dis: DataInputStream): Array[_] = {
+  // All elements of an array must be of the same type
+  def readArray(dis: DataInputStream): Array[_] = {
     val arrType = readObjectType(dis)
     arrType match {
       case 'i' => readIntArr(dis)
@@ -150,26 +170,43 @@ private[spark] object SerDe {
       case 'b' => readBooleanArr(dis)
       case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
       case 'r' => readBytesArr(dis)
-      case 'l' => {
+      case 'a' =>
+        val len = readInt(dis)
+        (0 until len).map(_ => readArray(dis)).toArray
+      case 'l' =>
         val len = readInt(dis)
         (0 until len).map(_ => readList(dis)).toArray
-      }
-      case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
+      case _ =>
+        if (sqlSerDe == null || sqlSerDe._1 == null) {
+          throw new IllegalArgumentException(s"Invalid array type $arrType")
+        } else {
+          val len = readInt(dis)
+          (0 until len).map { _ =>
+            val obj = (sqlSerDe._1)(dis, arrType)
+            if (obj == null) {
+              throw new IllegalArgumentException(s"Invalid array type $arrType")
+            } else {
+              obj
+            }
+          }.toArray
+        }
     }
   }
 
+  // Each element of a list can be of a different type. They are all
+  // represented as Object on the JVM side
+  def readList(dis: DataInputStream): Array[Object] = {
+    val len = readInt(dis)
+    (0 until len).map(_ => readObject(dis)).toArray
+  }
+
   def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
     val len = readInt(in)
     if (len > 0) {
-      val keysType = readObjectType(in)
-      val keysLen = readInt(in)
-      val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
-
-      val valuesLen = readInt(in)
-      val values = (0 until valuesLen).map(_ => {
-        val valueType = readObjectType(in)
-        readTypedObject(in, valueType)
-      })
+      // The keys are sent as an array of String
+      val keys = readArray(in).asInstanceOf[Array[Object]]
+      val values = readList(in)
+
       keys.zip(values).toMap.asJava
     } else {
       new java.util.HashMap[Object, Object]()
@@ -338,8 +375,10 @@ private[spark] object SerDe {
           }
 
         case _ =>
-          writeType(dos, "jobj")
-          writeJObj(dos, value)
+          if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) {
+            writeType(dos, "jobj")
+            writeJObj(dos, value)
+          }
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5e3868ba/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index f45d119..b0120a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -22,13 +22,15 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.r.SerDe
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, GenericRowWithSchema}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
 
 import scala.util.matching.Regex
 
 private[r] object SQLUtils {
+  SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
+
   def createSQLContext(jsc: JavaSparkContext): SQLContext = {
     new SQLContext(jsc)
   }
@@ -61,15 +63,27 @@ private[r] object SQLUtils {
       case "boolean" => org.apache.spark.sql.types.BooleanType
       case "timestamp" => org.apache.spark.sql.types.TimestampType
       case "date" => org.apache.spark.sql.types.DateType
-      case r"\Aarray<(.*)${elemType}>\Z" => {
+      case r"\Aarray<(.+)${elemType}>\Z" =>
         org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
-      }
-      case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => {
+      case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" =>
         if (keyType != "string" && keyType != "character") {
           throw new IllegalArgumentException("Key type of a map must be string or character")
         }
         org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType))
-      }
+      case r"\Astruct<(.+)${fieldsStr}>\Z" =>
+        if (fieldsStr(fieldsStr.length - 1) == ',') {
+          throw new IllegalArgumentException(s"Invalid type $dataType")
+        }
+        val fields = fieldsStr.split(",")
+        val structFields = fields.map { field =>
+          field match {
+            case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" =>
+              createStructField(fieldName, fieldType, true)
+
+            case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+          }
+        }
+        createStructType(structFields)
      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
     }
   }
@@ -151,4 +165,27 @@ private[r] object SQLUtils {
       options: java.util.Map[String, String]): DataFrame = {
     sqlContext.read.format(source).schema(schema).options(options).load()
   }
+
+  def readSqlObject(dis: DataInputStream, dataType: Char): Object = {
+    dataType match {
+      case 's' =>
+        // Read StructType for DataFrame
+        val fields = SerDe.readList(dis).asInstanceOf[Array[Object]]
+        Row.fromSeq(fields)
+      case _ => null
+    }
+  }
+
+  def writeSqlObject(dos: DataOutputStream, obj: Object): Boolean = {
+    obj match {
+      // Handle struct type in DataFrame
+      case v: GenericRowWithSchema =>
+        dos.writeByte('s')
+        SerDe.writeObject(dos, v.schema.fieldNames)
+        SerDe.writeObject(dos, v.values)
+        true
+      case _ =>
+        false
+    }
+  }
 }

