Posted to commits@spark.apache.org by gu...@apache.org on 2017/11/13 04:16:02 UTC
spark git commit: [SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas
Repository: spark
Updated Branches:
refs/heads/master 3d90b2cb3 -> 209b9361a
[SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas
## What changes were proposed in this pull request?
This change uses Arrow to optimize the creation of a Spark DataFrame from a Pandas DataFrame. The input Pandas DataFrame is sliced into chunks according to the SparkContext's default parallelism, each chunk is converted to an Arrow record batch, and the batches are sent to the JVM to be parallelized into the resulting DataFrame. The optimization is gated by the existing conf "spark.sql.execution.arrow.enabled" and is disabled by default.
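For illustration, a minimal usage sketch of the optimized path (assumes Spark 2.3+ with pandas and pyarrow installed; the sample data and app name are made up):

    from pyspark.sql import SparkSession
    import pandas as pd

    spark = SparkSession.builder.appName("arrow-createDataFrame-demo").getOrCreate()

    # Opt in to the Arrow optimization; it is disabled by default.
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")

    pdf = pd.DataFrame({"id": [1, 2, 3], "score": [0.1, 0.2, 0.3]})

    # With the conf enabled, the pandas DataFrame is sliced by the default
    # parallelism, converted to Arrow record batches, and parallelized in the JVM.
    df = spark.createDataFrame(pdf)
    df.show()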
## How was this patch tested?
Added new unit tests that create a DataFrame from Pandas with and without the optimization enabled and compare the results.
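Roughly, the comparison the new tests make looks like this (a standalone sketch rather than the actual test code; the sample data is illustrative):

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    pdf = pd.DataFrame({"a": [1, 2, 3], "b": [u"x", u"y", u"z"]})

    # Build the same DataFrame with the Arrow path disabled and enabled,
    # then check that both paths produce identical rows.
    spark.conf.set("spark.sql.execution.arrow.enabled", "false")
    expected = spark.createDataFrame(pdf).collect()
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    result = spark.createDataFrame(pdf).collect()
    assert result == expected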
Author: Bryan Cutler <cu...@gmail.com>
Author: Takuya UESHIN <ue...@databricks.com>
Closes #19459 from BryanCutler/arrow-createDataFrame-from_pandas-SPARK-20791.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/209b9361
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/209b9361
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/209b9361
Branch: refs/heads/master
Commit: 209b9361ac8a4410ff797cff1115e1888e2f7e66
Parents: 3d90b2c
Author: Bryan Cutler <cu...@gmail.com>
Authored: Mon Nov 13 13:16:01 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Mon Nov 13 13:16:01 2017 +0900
----------------------------------------------------------------------
python/pyspark/context.py | 28 +++---
python/pyspark/java_gateway.py | 1 +
python/pyspark/serializers.py | 10 ++-
python/pyspark/sql/session.py | 88 +++++++++++++++----
python/pyspark/sql/tests.py | 89 +++++++++++++++++---
python/pyspark/sql/types.py | 49 +++++++++++
.../spark/sql/api/python/PythonSQLUtils.scala | 18 ++++
.../sql/execution/arrow/ArrowConverters.scala | 14 +++
8 files changed, 254 insertions(+), 43 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a33f6dc..24905f1 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -475,24 +475,30 @@ class SparkContext(object):
return xrange(getStart(split), getStart(split + 1), step)
return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
- # Calling the Java parallelize() method with an ArrayList is too slow,
- # because it sends O(n) Py4J commands. As an alternative, serialized
- # objects are written to a file and loaded through textFile().
+
+ # Make sure we distribute data evenly if it's smaller than self.batchSize
+ if "__len__" not in dir(c):
+ c = list(c) # Make it a list so we can compute its length
+ batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+ serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+ jrdd = self._serialize_to_jvm(c, numSlices, serializer)
+ return RDD(jrdd, self, serializer)
+
+ def _serialize_to_jvm(self, data, parallelism, serializer):
+ """
+ Calling the Java parallelize() method with an ArrayList is too slow,
+ because it sends O(n) Py4J commands. As an alternative, serialized
+ objects are written to a file and loaded through textFile().
+ """
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
try:
- # Make sure we distribute data evenly if it's smaller than self.batchSize
- if "__len__" not in dir(c):
- c = list(c) # Make it a list so we can compute its length
- batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
- serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
- serializer.dump_stream(c, tempFile)
+ serializer.dump_stream(data, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
- jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+ return readRDDFromFile(self._jsc, tempFile.name, parallelism)
finally:
# readRDDFromFile eagerly reads the file, so we can delete it right after.
os.unlink(tempFile.name)
- return RDD(jrdd, self, serializer)
def pickleFile(self, name, minPartitions=None):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/java_gateway.py
----------------------------------------------------------------------
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3c783ae..3e704fe 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -121,6 +121,7 @@ def launch_gateway(conf=None):
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d7979f0..e0afdaf 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -214,6 +214,13 @@ class ArrowSerializer(FramedSerializer):
def _create_batch(series):
+ """
+ Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
+
+ :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
+ :return: Arrow RecordBatch
+ """
+
from pyspark.sql.types import _check_series_convert_timestamps_internal
import pyarrow as pa
# Make input conform to [(series1, type1), (series2, type2), ...]
@@ -229,7 +236,8 @@ def _create_batch(series):
# NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
return _check_series_convert_timestamps_internal(s.fillna(0))\
.values.astype('datetime64[us]', copy=False)
- elif t == pa.date32():
+ # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
+ elif t is not None and t == pa.date32():
# TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
return s.dt.date
elif t is None or s.dtype == t.to_pandas_dtype():
http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index d1d0b8b..589365b 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -25,7 +25,7 @@ if sys.version >= '3':
basestring = unicode = str
xrange = range
else:
- from itertools import imap as map
+ from itertools import izip as zip, imap as map
from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
@@ -417,12 +417,12 @@ class SparkSession(object):
data = [schema.toInternal(row) for row in data]
return self._sc.parallelize(data), schema
- def _get_numpy_record_dtypes(self, rec):
+ def _get_numpy_record_dtype(self, rec):
"""
Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
- the dtypes of records so they can be properly loaded into Spark.
- :param rec: a numpy record to check dtypes
- :return corrected dtypes for a numpy.record or None if no correction needed
+ the dtypes of fields in a record so they can be properly loaded into Spark.
+ :param rec: a numpy record to check field dtypes
+ :return corrected dtype for a numpy.record or None if no correction needed
"""
import numpy as np
cur_dtypes = rec.dtype
@@ -438,28 +438,70 @@ class SparkSession(object):
curr_type = 'datetime64[us]'
has_rec_fix = True
record_type_list.append((str(col_names[i]), curr_type))
- return record_type_list if has_rec_fix else None
+ return np.dtype(record_type_list) if has_rec_fix else None
- def _convert_from_pandas(self, pdf, schema):
+ def _convert_from_pandas(self, pdf):
"""
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
- :return tuple of list of records and schema
+ :return list of records
"""
- # If no schema supplied by user then get the names of columns only
- if schema is None:
- schema = [str(x) for x in pdf.columns]
# Convert pandas.DataFrame to list of numpy records
np_records = pdf.to_records(index=False)
# Check if any columns need to be fixed for Spark to infer properly
if len(np_records) > 0:
- record_type_list = self._get_numpy_record_dtypes(np_records[0])
- if record_type_list is not None:
- return [r.astype(record_type_list).tolist() for r in np_records], schema
+ record_dtype = self._get_numpy_record_dtype(np_records[0])
+ if record_dtype is not None:
+ return [r.astype(record_dtype).tolist() for r in np_records]
# Convert list of numpy records to python lists
- return [r.tolist() for r in np_records], schema
+ return [r.tolist() for r in np_records]
+
+ def _create_from_pandas_with_arrow(self, pdf, schema):
+ """
+ Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
+ to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
+ data types will be used to coerce the data in Pandas to Arrow conversion.
+ """
+ from pyspark.serializers import ArrowSerializer, _create_batch
+ from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
+ from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+
+ # Determine arrow types to coerce data when creating batches
+ if isinstance(schema, StructType):
+ arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+ elif isinstance(schema, DataType):
+ raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
+ else:
+ # Any timestamps must be coerced to be compatible with Spark
+ arrow_types = [to_arrow_type(TimestampType())
+ if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
+ for t in pdf.dtypes]
+
+ # Slice the DataFrame to be batched
+ step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
+ pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
+
+ # Create Arrow record batches
+ batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
+ for pdf_slice in pdf_slices]
+
+ # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
+ if isinstance(schema, (list, tuple)):
+ struct = from_arrow_schema(batches[0].schema)
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ struct.names[i] = name
+ schema = struct
+
+ # Create the Spark DataFrame directly from the Arrow data and schema
+ jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
+ jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
+ jrdd, schema.json(), self._wrapped._jsqlContext)
+ df = DataFrame(jdf, self._wrapped)
+ df._schema = schema
+ return df
@since(2.0)
@ignore_unicode_prefix
@@ -557,7 +599,19 @@ class SparkSession(object):
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
- data, schema = self._convert_from_pandas(data, schema)
+
+ # If no schema supplied by user then get the names of columns only
+ if schema is None:
+ schema = [str(x) for x in data.columns]
+
+ if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
+ and len(data) > 0:
+ try:
+ return self._create_from_pandas_with_arrow(data, schema)
+ except Exception as e:
+ warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
+ # Fall back to creating the DataFrame without Arrow if an exception is raised
+ data = self._convert_from_pandas(data)
if isinstance(schema, StructType):
verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
@@ -576,7 +630,7 @@ class SparkSession(object):
verify_func(obj)
return obj,
else:
- if isinstance(schema, list):
+ if isinstance(schema, (list, tuple)):
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
prepare = lambda obj: obj
http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4819f62..6356d93 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3127,9 +3127,9 @@ class ArrowTests(ReusedSQLTestCase):
StructField("5_double_t", DoubleType(), True),
StructField("6_date_t", DateType(), True),
StructField("7_timestamp_t", TimestampType(), True)])
- cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
- ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
- ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+ cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+ (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+ (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
@classmethod
def tearDownClass(cls):
@@ -3145,6 +3145,17 @@ class ArrowTests(ReusedSQLTestCase):
("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
+ def create_pandas_data_frame(self):
+ import pandas as pd
+ import numpy as np
+ data_dict = {}
+ for j, name in enumerate(self.schema.names):
+ data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+ # need to convert these to numpy types first
+ data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
+ data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
+ return pd.DataFrame(data=data_dict)
+
def test_unsupported_datatype(self):
schema = StructType([StructField("decimal", DecimalType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -3161,21 +3172,15 @@ class ArrowTests(ReusedSQLTestCase):
def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
- pdf = df.toPandas()
- self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+ try:
+ pdf = df.toPandas()
+ finally:
+ self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)
def test_pandas_round_trip(self):
- import pandas as pd
- import numpy as np
- data_dict = {}
- for j, name in enumerate(self.schema.names):
- data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
- # need to convert these to numpy types first
- data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
- data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
- pdf = pd.DataFrame(data=data_dict)
+ pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(self.data, schema=self.schema)
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)
@@ -3187,6 +3192,62 @@ class ArrowTests(ReusedSQLTestCase):
self.assertEqual(pdf.columns[0], "i")
self.assertTrue(pdf.empty)
+ def test_createDataFrame_toggle(self):
+ pdf = self.create_pandas_data_frame()
+ self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
+ try:
+ df_no_arrow = self.spark.createDataFrame(pdf)
+ finally:
+ self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+ df_arrow = self.spark.createDataFrame(pdf)
+ self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
+
+ def test_createDataFrame_with_schema(self):
+ pdf = self.create_pandas_data_frame()
+ df = self.spark.createDataFrame(pdf, schema=self.schema)
+ self.assertEquals(self.schema, df.schema)
+ pdf_arrow = df.toPandas()
+ self.assertFramesEqual(pdf_arrow, pdf)
+
+ def test_createDataFrame_with_incorrect_schema(self):
+ pdf = self.create_pandas_data_frame()
+ wrong_schema = StructType(list(reversed(self.schema)))
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+ self.spark.createDataFrame(pdf, schema=wrong_schema)
+
+ def test_createDataFrame_with_names(self):
+ pdf = self.create_pandas_data_frame()
+ # Test that schema as a list of column names gets applied
+ df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
+ self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+ # Test that schema as tuple of column names gets applied
+ df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
+ self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+
+ def test_createDataFrame_with_single_data_type(self):
+ import pandas as pd
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+ self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
+
+ def test_createDataFrame_does_not_modify_input(self):
+ # Some series get converted for Spark to consume; this makes sure the input is unchanged
+ pdf = self.create_pandas_data_frame()
+ # Use a nanosecond value to make sure it is not truncated
+ pdf.ix[0, '7_timestamp_t'] = 1
+ # Integers with nulls will get NaNs filled with 0 and will be cast
+ pdf.ix[1, '2_int_t'] = None
+ pdf_copy = pdf.copy(deep=True)
+ self.spark.createDataFrame(pdf, schema=self.schema)
+ self.assertTrue(pdf.equals(pdf_copy))
+
+ def test_schema_conversion_roundtrip(self):
+ from pyspark.sql.types import from_arrow_schema, to_arrow_schema
+ arrow_schema = to_arrow_schema(self.schema)
+ schema_rt = from_arrow_schema(arrow_schema)
+ self.assertEquals(self.schema, schema_rt)
+
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class VectorizedUDFTests(ReusedSQLTestCase):
http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 7dd8fa0..fe62f60 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1629,6 +1629,55 @@ def to_arrow_type(dt):
return arrow_type
+def to_arrow_schema(schema):
+ """ Convert a schema from Spark to Arrow
+ """
+ import pyarrow as pa
+ fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+ for field in schema]
+ return pa.schema(fields)
+
+
+def from_arrow_type(at):
+ """ Convert pyarrow type to Spark data type.
+ """
+ # TODO: newer pyarrow has is_boolean(at) and similar functions, which would be a better way to check types
+ import pyarrow as pa
+ if at == pa.bool_():
+ spark_type = BooleanType()
+ elif at == pa.int8():
+ spark_type = ByteType()
+ elif at == pa.int16():
+ spark_type = ShortType()
+ elif at == pa.int32():
+ spark_type = IntegerType()
+ elif at == pa.int64():
+ spark_type = LongType()
+ elif at == pa.float32():
+ spark_type = FloatType()
+ elif at == pa.float64():
+ spark_type = DoubleType()
+ elif type(at) == pa.DecimalType:
+ spark_type = DecimalType(precision=at.precision, scale=at.scale)
+ elif at == pa.string():
+ spark_type = StringType()
+ elif at == pa.date32():
+ spark_type = DateType()
+ elif type(at) == pa.TimestampType:
+ spark_type = TimestampType()
+ else:
+ raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
+ return spark_type
+
+
+def from_arrow_schema(arrow_schema):
+ """ Convert schema from Arrow to Spark.
+ """
+ return StructType(
+ [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+ for field in arrow_schema])
+
+
def _check_dataframe_localize_timestamps(pdf):
"""
Convert timezone aware timestamps to timezone-naive in local time
http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 4d5ce0b..b33760b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -17,9 +17,12 @@
package org.apache.spark.sql.api.python
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.DataType
private[sql] object PythonSQLUtils {
@@ -29,4 +32,19 @@ private[sql] object PythonSQLUtils {
def listBuiltinFunctionInfos(): Array[ExpressionInfo] = {
FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray
}
+
+ /**
+ * Python Callable function to convert ArrowPayloads into a [[DataFrame]].
+ *
+ * @param payloadRDD A JavaRDD of ArrowPayloads.
+ * @param schemaString JSON Formatted Schema for ArrowPayloads.
+ * @param sqlContext The active [[SQLContext]].
+ * @return The converted [[DataFrame]].
+ */
+ def arrowPayloadToDataFrame(
+ payloadRDD: JavaRDD[Array[Byte]],
+ schemaString: String,
+ sqlContext: SQLContext): DataFrame = {
+ ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/209b9361/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 05ea151..3cafb34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -29,6 +29,8 @@ import org.apache.arrow.vector.schema.ArrowRecordBatch
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.spark.TaskContext
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types._
@@ -204,4 +206,16 @@ private[sql] object ArrowConverters {
reader.close()
}
}
+
+ private[sql] def toDataFrame(
+ payloadRDD: JavaRDD[Array[Byte]],
+ schemaString: String,
+ sqlContext: SQLContext): DataFrame = {
+ val rdd = payloadRDD.rdd.mapPartitions { iter =>
+ val context = TaskContext.get()
+ ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+ }
+ val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+ sqlContext.internalCreateDataFrame(rdd, schema)
+ }
}