You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2017/11/13 04:16:02 UTC

spark git commit: [SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas

Repository: spark
Updated Branches:
  refs/heads/master 3d90b2cb3 -> 209b9361a


[SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas

## What changes were proposed in this pull request?

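This change uses Arrow to optimize the creation of a Spark DataFrame from a Pandas DataFrame. The input pandas DataFrame is sliced according to the default parallelism, each slice is converted to an Arrow record batch, and the batches are sent to the JVM to build the DataFrame. The optimization is enabled with the existing conf "spark.sql.execution.arrow.enabled", which is disabled by default; if the Arrow path fails, createDataFrame emits a warning and falls back to the non-Arrow conversion.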

## How was this patch tested?

Added new unit tests that create a DataFrame from a Pandas DataFrame with and without the optimization enabled, then compare the results.

Author: Bryan Cutler <cu...@gmail.com>
Author: Takuya UESHIN <ue...@databricks.com>

Closes #19459 from BryanCutler/arrow-createDataFrame-from_pandas-SPARK-20791.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/209b9361
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/209b9361
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/209b9361

Branch: refs/heads/master
Commit: 209b9361ac8a4410ff797cff1115e1888e2f7e66
Parents: 3d90b2c
Author: Bryan Cutler <cu...@gmail.com>
Authored: Mon Nov 13 13:16:01 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Mon Nov 13 13:16:01 2017 +0900

----------------------------------------------------------------------
 python/pyspark/context.py                       | 28 +++---
 python/pyspark/java_gateway.py                  |  1 +
 python/pyspark/serializers.py                   | 10 ++-
 python/pyspark/sql/session.py                   | 88 +++++++++++++++----
 python/pyspark/sql/tests.py                     | 89 +++++++++++++++++---
 python/pyspark/sql/types.py                     | 49 +++++++++++
 .../spark/sql/api/python/PythonSQLUtils.scala   | 18 ++++
 .../sql/execution/arrow/ArrowConverters.scala   | 14 +++
 8 files changed, 254 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a33f6dc..24905f1 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -475,24 +475,30 @@ class SparkContext(object):
                 return xrange(getStart(split), getStart(split + 1), step)
 
             return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands.  As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
+
+        # Make sure we distribute data evenly if it's smaller than self.batchSize
+        if "__len__" not in dir(c):
+            c = list(c)    # Make it a list so we can compute its length
+        batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+        serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
+        return RDD(jrdd, self, serializer)
+
+    def _serialize_to_jvm(self, data, parallelism, serializer):
+        """
+        Calling the Java parallelize() method with an ArrayList is too slow,
+        because it sends O(n) Py4J commands.  As an alternative, serialized
+        objects are written to a file and loaded through textFile().
+        """
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         try:
-            # Make sure we distribute data evenly if it's smaller than self.batchSize
-            if "__len__" not in dir(c):
-                c = list(c)    # Make it a list so we can compute its length
-            batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
-            serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
-            serializer.dump_stream(c, tempFile)
+            serializer.dump_stream(data, tempFile)
             tempFile.close()
             readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
-            jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+            return readRDDFromFile(self._jsc, tempFile.name, parallelism)
         finally:
             # readRDDFromFile eagerily reads the file so we can delete right after.
             os.unlink(tempFile.name)
-        return RDD(jrdd, self, serializer)
 
     def pickleFile(self, name, minPartitions=None):
         """

http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/java_gateway.py
----------------------------------------------------------------------
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3c783ae..3e704fe 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -121,6 +121,7 @@ def launch_gateway(conf=None):
     java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
     # TODO(davies): move into sql
     java_import(gateway.jvm, "org.apache.spark.sql.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
     java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
     java_import(gateway.jvm, "scala.Tuple2")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d7979f0..e0afdaf 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -214,6 +214,13 @@ class ArrowSerializer(FramedSerializer):
 
 
 def _create_batch(series):
+    """
+    Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
+
+    :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
+    :return: Arrow RecordBatch
+    """
+
     from pyspark.sql.types import _check_series_convert_timestamps_internal
     import pyarrow as pa
     # Make input conform to [(series1, type1), (series2, type2), ...]
@@ -229,7 +236,8 @@ def _create_batch(series):
             # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
             return _check_series_convert_timestamps_internal(s.fillna(0))\
                 .values.astype('datetime64[us]', copy=False)
-        elif t == pa.date32():
+        # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
+        elif t is not None and t == pa.date32():
             # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
             return s.dt.date
         elif t is None or s.dtype == t.to_pandas_dtype():

http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index d1d0b8b..589365b 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -25,7 +25,7 @@ if sys.version >= '3':
     basestring = unicode = str
     xrange = range
 else:
-    from itertools import imap as map
+    from itertools import izip as zip, imap as map
 
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
@@ -417,12 +417,12 @@ class SparkSession(object):
         data = [schema.toInternal(row) for row in data]
         return self._sc.parallelize(data), schema
 
-    def _get_numpy_record_dtypes(self, rec):
+    def _get_numpy_record_dtype(self, rec):
         """
         Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
-        the dtypes of records so they can be properly loaded into Spark.
-        :param rec: a numpy record to check dtypes
-        :return corrected dtypes for a numpy.record or None if no correction needed
+        the dtypes of fields in a record so they can be properly loaded into Spark.
+        :param rec: a numpy record to check field dtypes
+        :return corrected dtype for a numpy.record or None if no correction needed
         """
         import numpy as np
         cur_dtypes = rec.dtype
@@ -438,28 +438,70 @@ class SparkSession(object):
                 curr_type = 'datetime64[us]'
                 has_rec_fix = True
             record_type_list.append((str(col_names[i]), curr_type))
-        return record_type_list if has_rec_fix else None
+        return np.dtype(record_type_list) if has_rec_fix else None
 
-    def _convert_from_pandas(self, pdf, schema):
+    def _convert_from_pandas(self, pdf):
         """
          Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
-         :return tuple of list of records and schema
+         :return list of records
         """
-        # If no schema supplied by user then get the names of columns only
-        if schema is None:
-            schema = [str(x) for x in pdf.columns]
 
         # Convert pandas.DataFrame to list of numpy records
         np_records = pdf.to_records(index=False)
 
         # Check if any columns need to be fixed for Spark to infer properly
         if len(np_records) > 0:
-            record_type_list = self._get_numpy_record_dtypes(np_records[0])
-            if record_type_list is not None:
-                return [r.astype(record_type_list).tolist() for r in np_records], schema
+            record_dtype = self._get_numpy_record_dtype(np_records[0])
+            if record_dtype is not None:
+                return [r.astype(record_dtype).tolist() for r in np_records]
 
         # Convert list of numpy records to python lists
-        return [r.tolist() for r in np_records], schema
+        return [r.tolist() for r in np_records]
+
+    def _create_from_pandas_with_arrow(self, pdf, schema):
+        """
+        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
+        to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
+        data types will be used to coerce the data in Pandas to Arrow conversion.
+        """
+        from pyspark.serializers import ArrowSerializer, _create_batch
+        from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
+        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+
+        # Determine arrow types to coerce data when creating batches
+        if isinstance(schema, StructType):
+            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+        elif isinstance(schema, DataType):
+            raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
+        else:
+            # Any timestamps must be coerced to be compatible with Spark
+            arrow_types = [to_arrow_type(TimestampType())
+                           if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
+                           for t in pdf.dtypes]
+
+        # Slice the DataFrame to be batched
+        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
+        pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
+
+        # Create Arrow record batches
+        batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
+                   for pdf_slice in pdf_slices]
+
+        # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
+        if isinstance(schema, (list, tuple)):
+            struct = from_arrow_schema(batches[0].schema)
+            for i, name in enumerate(schema):
+                struct.fields[i].name = name
+                struct.names[i] = name
+            schema = struct
+
+        # Create the Spark DataFrame directly from the Arrow data and schema
+        jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
+        jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
+            jrdd, schema.json(), self._wrapped._jsqlContext)
+        df = DataFrame(jdf, self._wrapped)
+        df._schema = schema
+        return df
 
     @since(2.0)
     @ignore_unicode_prefix
@@ -557,7 +599,19 @@ class SparkSession(object):
         except Exception:
             has_pandas = False
         if has_pandas and isinstance(data, pandas.DataFrame):
-            data, schema = self._convert_from_pandas(data, schema)
+
+            # If no schema supplied by user then get the names of columns only
+            if schema is None:
+                schema = [str(x) for x in data.columns]
+
+            if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
+                    and len(data) > 0:
+                try:
+                    return self._create_from_pandas_with_arrow(data, schema)
+                except Exception as e:
+                    warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
+                    # Fallback to create DataFrame without arrow if raise some exception
+            data = self._convert_from_pandas(data)
 
         if isinstance(schema, StructType):
             verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
@@ -576,7 +630,7 @@ class SparkSession(object):
                 verify_func(obj)
                 return obj,
         else:
-            if isinstance(schema, list):
+            if isinstance(schema, (list, tuple)):
                 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
             prepare = lambda obj: obj
 

http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4819f62..6356d93 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3127,9 +3127,9 @@ class ArrowTests(ReusedSQLTestCase):
             StructField("5_double_t", DoubleType(), True),
             StructField("6_date_t", DateType(), True),
             StructField("7_timestamp_t", TimestampType(), True)])
-        cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-                    ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
-                    ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+        cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
@@ -3145,6 +3145,17 @@ class ArrowTests(ReusedSQLTestCase):
                ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
 
+    def create_pandas_data_frame(self):
+        import pandas as pd
+        import numpy as np
+        data_dict = {}
+        for j, name in enumerate(self.schema.names):
+            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+        # need to convert these to numpy types first
+        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
+        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
+        return pd.DataFrame(data=data_dict)
+
     def test_unsupported_datatype(self):
         schema = StructType([StructField("decimal", DecimalType(), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -3161,21 +3172,15 @@ class ArrowTests(ReusedSQLTestCase):
     def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        pdf = df.toPandas()
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        try:
+            pdf = df.toPandas()
+        finally:
+            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         pdf_arrow = df.toPandas()
         self.assertFramesEqual(pdf_arrow, pdf)
 
     def test_pandas_round_trip(self):
-        import pandas as pd
-        import numpy as np
-        data_dict = {}
-        for j, name in enumerate(self.schema.names):
-            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
-        # need to convert these to numpy types first
-        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
-        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
-        pdf = pd.DataFrame(data=data_dict)
+        pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf_arrow = df.toPandas()
         self.assertFramesEqual(pdf_arrow, pdf)
@@ -3187,6 +3192,62 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEqual(pdf.columns[0], "i")
         self.assertTrue(pdf.empty)
 
+    def test_createDataFrame_toggle(self):
+        pdf = self.create_pandas_data_frame()
+        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
+        try:
+            df_no_arrow = self.spark.createDataFrame(pdf)
+        finally:
+            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        df_arrow = self.spark.createDataFrame(pdf)
+        self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
+
+    def test_createDataFrame_with_schema(self):
+        pdf = self.create_pandas_data_frame()
+        df = self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertEquals(self.schema, df.schema)
+        pdf_arrow = df.toPandas()
+        self.assertFramesEqual(pdf_arrow, pdf)
+
+    def test_createDataFrame_with_incorrect_schema(self):
+        pdf = self.create_pandas_data_frame()
+        wrong_schema = StructType(list(reversed(self.schema)))
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+                self.spark.createDataFrame(pdf, schema=wrong_schema)
+
+    def test_createDataFrame_with_names(self):
+        pdf = self.create_pandas_data_frame()
+        # Test that schema as a list of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        # Test that schema as tuple of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+
+    def test_createDataFrame_with_single_data_type(self):
+        import pandas as pd
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+                self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
+
+    def test_createDataFrame_does_not_modify_input(self):
+        # Some series get converted for Spark to consume, this makes sure input is unchanged
+        pdf = self.create_pandas_data_frame()
+        # Use a nanosecond value to make sure it is not truncated
+        pdf.ix[0, '7_timestamp_t'] = 1
+        # Integers with nulls will get NaNs filled with 0 and will be casted
+        pdf.ix[1, '2_int_t'] = None
+        pdf_copy = pdf.copy(deep=True)
+        self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertTrue(pdf.equals(pdf_copy))
+
+    def test_schema_conversion_roundtrip(self):
+        from pyspark.sql.types import from_arrow_schema, to_arrow_schema
+        arrow_schema = to_arrow_schema(self.schema)
+        schema_rt = from_arrow_schema(arrow_schema)
+        self.assertEquals(self.schema, schema_rt)
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class VectorizedUDFTests(ReusedSQLTestCase):

http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 7dd8fa0..fe62f60 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1629,6 +1629,55 @@ def to_arrow_type(dt):
     return arrow_type
 
 
+def to_arrow_schema(schema):
+    """ Convert a schema from Spark to Arrow
+    """
+    import pyarrow as pa
+    fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+              for field in schema]
+    return pa.schema(fields)
+
+
+def from_arrow_type(at):
+    """ Convert pyarrow type to Spark data type.
+    """
+    # TODO: newer pyarrow has is_boolean(at) functions that would be better to check type
+    import pyarrow as pa
+    if at == pa.bool_():
+        spark_type = BooleanType()
+    elif at == pa.int8():
+        spark_type = ByteType()
+    elif at == pa.int16():
+        spark_type = ShortType()
+    elif at == pa.int32():
+        spark_type = IntegerType()
+    elif at == pa.int64():
+        spark_type = LongType()
+    elif at == pa.float32():
+        spark_type = FloatType()
+    elif at == pa.float64():
+        spark_type = DoubleType()
+    elif type(at) == pa.DecimalType:
+        spark_type = DecimalType(precision=at.precision, scale=at.scale)
+    elif at == pa.string():
+        spark_type = StringType()
+    elif at == pa.date32():
+        spark_type = DateType()
+    elif type(at) == pa.TimestampType:
+        spark_type = TimestampType()
+    else:
+        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
+    return spark_type
+
+
+def from_arrow_schema(arrow_schema):
+    """ Convert schema from Arrow to Spark.
+    """
+    return StructType(
+        [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+         for field in arrow_schema])
+
+
 def _check_dataframe_localize_timestamps(pdf):
     """
     Convert timezone aware timestamps to timezone-naive in local time

http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 4d5ce0b..b33760b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.api.python
 
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.types.DataType
 
 private[sql] object PythonSQLUtils {
@@ -29,4 +32,19 @@ private[sql] object PythonSQLUtils {
   def listBuiltinFunctionInfos(): Array[ExpressionInfo] = {
     FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray
   }
+
+  /**
+   * Python Callable function to convert ArrowPayloads into a [[DataFrame]].
+   *
+   * @param payloadRDD A JavaRDD of ArrowPayloads.
+   * @param schemaString JSON Formatted Schema for ArrowPayloads.
+   * @param sqlContext The active [[SQLContext]].
+   * @return The converted [[DataFrame]].
+   */
+  def arrowPayloadToDataFrame(
+      payloadRDD: JavaRDD[Array[Byte]],
+      schemaString: String,
+      sqlContext: SQLContext): DataFrame = {
+    ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 05ea151..3cafb34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -29,6 +29,8 @@ import org.apache.arrow.vector.schema.ArrowRecordBatch
 import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 
 import org.apache.spark.TaskContext
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 import org.apache.spark.sql.types._
@@ -204,4 +206,16 @@ private[sql] object ArrowConverters {
       reader.close()
     }
   }
+
+  private[sql] def toDataFrame(
+      payloadRDD: JavaRDD[Array[Byte]],
+      schemaString: String,
+      sqlContext: SQLContext): DataFrame = {
+    val rdd = payloadRDD.rdd.mapPartitions { iter =>
+      val context = TaskContext.get()
+      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+    }
+    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+    sqlContext.internalCreateDataFrame(rdd, schema)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org