Posted to commits@spark.apache.org by gu...@apache.org on 2022/06/13 00:06:48 UTC
[spark] branch master updated: [SPARK-39406][PYTHON] Accept NumPy array in createDataFrame
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5d393dde751 [SPARK-39406][PYTHON] Accept NumPy array in createDataFrame
5d393dde751 is described below
commit 5d393dde751181287a7fa8cc8f0f56bdb9d1d1ac
Author: Xinrong Meng <xi...@databricks.com>
AuthorDate: Mon Jun 13 09:06:36 2022 +0900
[SPARK-39406][PYTHON] Accept NumPy array in createDataFrame
### What changes were proposed in this pull request?
Accept NumPy arrays in createDataFrame, with support for existing dtypes.
Note that
- due to the constraints of Spark DataFrames, only 1-dimensional and 2-dimensional arrays are supported (see the example below).
- full NumPy <> PySpark type mappings will be introduced as a follow-up.
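For example, anything other than a 1- or 2-dimensional array is rejected up front:
```py
>>> import numpy as np
>>> spark.createDataFrame(np.array(0))  # 0-dimensional input
Traceback (most recent call last):
...
ValueError: NumPy array input should be of 1 or 2 dimensions.
```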
### Why are the changes needed?
As part of SPARK-39405 for NumPy support in SQL.
### Does this PR introduce _any_ user-facing change?
Yes, NumPy arrays are now accepted in createDataFrame:
Before:
```py
>>> spark.createDataFrame(np.array([[1, 2], [3, 4]]))
Traceback (most recent call last):
...
TypeError: Can not infer schema for type: <class 'numpy.ndarray'>
>>> spark.createDataFrame(np.array([0.1, 0.2]))
Traceback (most recent call last):
...
TypeError: Can not infer schema for type: <class 'numpy.float64'>
```
After:
```py
>>> spark.createDataFrame(np.array([[1, 2], [3, 4]])).show()
+---+---+
| _1| _2|
+---+---+
|  1|  2|
|  3|  4|
+---+---+
>>> spark.createDataFrame(np.array([0.1, 0.2])).show()
+-----+
|value|
+-----+
|  0.1|
|  0.2|
+-----+
```
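Under the hood, the ndarray is converted to a pandas DataFrame and handed off to the existing pandas code path, so the conversion is roughly equivalent to the following sketch (illustrative only; `arr` and `names` are stand-in variable names):
```py
>>> import numpy as np
>>> import pandas as pd
>>> arr = np.array([[1, 2], [3, 4]])
>>> # 1-dim input gets a single "value" column; 2-dim input gets positional names
>>> names = ["value"] if arr.ndim == 1 else ["_%s" % i for i in range(1, arr.shape[1] + 1)]
>>> spark.createDataFrame(pd.DataFrame(arr, columns=names)).show()
+---+---+
| _1| _2|
+---+---+
|  1|  2|
|  3|  4|
+---+---+
```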
### How was this patch tested?
Unit tests.
Closes #36793 from xinrong-databricks/createDataFrame2.
Authored-by: Xinrong Meng <xi...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/docs/source/getting_started/install.rst | 2 +-
python/pyspark/sql/pandas/_typing/__init__.pyi | 2 ++
python/pyspark/sql/session.py | 34 ++++++++++++++++-----
python/pyspark/sql/tests/test_arrow.py | 35 ++++++++++++++++++++--
python/setup.py | 1 +
.../org/apache/spark/sql/internal/SQLConf.scala | 3 +-
6 files changed, 65 insertions(+), 12 deletions(-)
diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst
index e5c1455da7a..afcdb7291c5 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -159,7 +159,7 @@ Package Minimum supported version Note
`py4j` 0.10.9.5 Required
`pandas` 1.0.5 Required for pandas API on Spark
`pyarrow` 1.0.0 Required for pandas API on Spark
-`numpy` 1.15 Required for pandas API on Spark and MLlib DataFrame-based API
+`numpy` 1.15 Required for pandas API on Spark and MLlib DataFrame-based API; Optional for Spark SQL
============= ========================= ======================================
Note that PySpark requires Java 8 or later with ``JAVA_HOME`` properly set.
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi
index 6ecd04f057e..27ac64a7238 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -32,9 +32,11 @@ from types import FunctionType
from pyspark.sql._typing import LiteralType
from pandas.core.frame import DataFrame as PandasDataFrame
from pandas.core.series import Series as PandasSeries
+from numpy import ndarray as NDArray
import pyarrow
+ArrayLike = NDArray
DataFrameLike = PandasDataFrame
SeriesLike = PandasSeries
DataFrameOrSeriesLike = Union[DataFrameLike, SeriesLike]
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 2ea8fa792e2..a4c36d719b2 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -63,7 +63,7 @@ from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_prefer
if TYPE_CHECKING:
from pyspark.sql._typing import AtomicValue, RowLike
from pyspark.sql.catalog import Catalog
- from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+ from pyspark.sql.pandas._typing import ArrayLike, DataFrameLike as PandasDataFrameLike
from pyspark.sql.streaming import StreamingQueryManager
from pyspark.sql.udf import UDFRegistration
@@ -837,13 +837,14 @@ class SparkSession(SparkConversionMixin):
def createDataFrame( # type: ignore[misc]
self,
- data: Union[RDD[Any], Iterable[Any], "PandasDataFrameLike"],
+ data: Union[RDD[Any], Iterable[Any], "PandasDataFrameLike", "ArrayLike"],
schema: Optional[Union[AtomicType, StructType, str]] = None,
samplingRatio: Optional[float] = None,
verifySchema: bool = True,
) -> DataFrame:
"""
- Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
+ Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`
+ or a :class:`numpy.ndarray`.
When ``schema`` is a list of column names, the type of each column
will be inferred from ``data``.
@@ -870,8 +871,8 @@ class SparkSession(SparkConversionMixin):
----------
data : :class:`RDD` or iterable
an RDD of any kind of SQL data representation (:class:`Row`,
- :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or
- :class:`pandas.DataFrame`.
+ :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`,
+ :class:`pandas.DataFrame` or :class:`numpy.ndarray`.
schema : :class:`pyspark.sql.types.DataType`, str or list, optional
a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
column names, default is None. The data type string format equals to
@@ -952,12 +953,31 @@ class SparkSession(SparkConversionMixin):
schema = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
try:
- import pandas
+ import pandas as pd
has_pandas = True
except Exception:
has_pandas = False
- if has_pandas and isinstance(data, pandas.DataFrame):
+
+ try:
+ import numpy as np
+
+ has_numpy = True
+ except Exception:
+ has_numpy = False
+
+ if has_numpy and isinstance(data, np.ndarray):
+ # `data` of numpy.ndarray type will be converted to a pandas DataFrame,
+ # so pandas is required.
+ from pyspark.sql.pandas.utils import require_minimum_pandas_version
+
+ require_minimum_pandas_version()
+ if data.ndim not in [1, 2]:
+ raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
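+ # A 1-dim array maps to a single "value" column; a 2-dim array gets positional column names.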
+ column_names = ["value"] if data.ndim == 1 else ["_%s" % i for i in range(1, data.shape[1] + 1)]
+ data = pd.DataFrame(data, columns=column_names)
+
+ if has_pandas and isinstance(data, pd.DataFrame):
# Create a DataFrame from pandas DataFrame.
return super(SparkSession, self).createDataFrame( # type: ignore[call-overload]
data, schema, samplingRatio, verifySchema
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index ff42ade1407..b737848b11a 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -24,6 +24,8 @@ import warnings
from distutils.version import LooseVersion
from typing import cast
+import numpy as np
+
from pyspark import SparkContext, SparkConf
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import rand, udf, assert_true, lit
@@ -179,6 +181,15 @@ class ArrowTests(ReusedSQLTestCase):
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)
+ @property
+ def create_np_arrs(self):
+ return [
+ np.array([1, 2]), # dtype('int64')
+ np.array([0.1, 0.2]), # dtype('float64')
+ np.array([[1, 2], [3, 4]]), # dtype('int64')
+ np.array([[0.1, 0.2], [0.3, 0.4]]), # dtype('float64')
+ ]
+
def test_toPandas_fallback_enabled(self):
ts = datetime.datetime(2015, 11, 1, 0, 30)
with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
@@ -391,11 +402,11 @@ class ArrowTests(ReusedSQLTestCase):
with self.assertRaisesRegex(Exception, "My error"):
df.toPandas()
- def _createDataFrame_toggle(self, pdf, schema=None):
+ def _createDataFrame_toggle(self, data, schema=None):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
- df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
+ df_no_arrow = self.spark.createDataFrame(data, schema=schema)
- df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+ df_arrow = self.spark.createDataFrame(data, schema=schema)
return df_no_arrow, df_arrow
@@ -495,6 +506,24 @@ class ArrowTests(ReusedSQLTestCase):
schema_rt = from_arrow_schema(arrow_schema)
self.assertEqual(self.schema, schema_rt)
+ def test_createDataFrame_with_ndarray(self):
+ arrs = self.create_np_arrs
+ collected_dtypes = [
+ ([Row(value=1), Row(value=2)], [("value", "bigint")]),
+ ([Row(value=0.1), Row(value=0.2)], [("value", "double")]),
+ ([Row(_1=1, _2=2), Row(_1=3, _2=4)], [("_1", "bigint"), ("_2", "bigint")]),
+ ([Row(_1=0.1, _2=0.2), Row(_1=0.3, _2=0.4)], [("_1", "double"), ("_2", "double")]),
+ ]
+ for arr, [collected, dtypes] in zip(arrs, collected_dtypes):
+ df, df_arrow = self._createDataFrame_toggle(arr)
+ self.assertEqual(df.dtypes, dtypes)
+ self.assertEqual(df_arrow.dtypes, dtypes)
+ self.assertEqual(df.collect(), collected)
+ self.assertEqual(df_arrow.collect(), collected)
+
+ with self.assertRaisesRegex(ValueError, "NumPy array input should be of 1 or 2 dimensions"):
+ self.spark.createDataFrame(np.array(0))
+
def test_createDataFrame_with_array_type(self):
pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [["x", "y"], ["y", "z"]]})
df, df_arrow = self._createDataFrame_toggle(pdf)
diff --git a/python/setup.py b/python/setup.py
index 6128b206223..061dc9d663d 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -266,6 +266,7 @@ try:
'sql': [
'pandas>=%s' % _minimum_pandas_version,
'pyarrow>=%s' % _minimum_pyarrow_version,
+ 'numpy>=1.15',
],
'pandas_on_spark': [
'pandas>=%s' % _minimum_pandas_version,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5e1f3956159..b8a752e90ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2557,8 +2557,9 @@ object SQLConf {
buildConf("spark.sql.execution.arrow.pyspark.enabled")
.doc("When true, make use of Apache Arrow for columnar data transfers in PySpark. " +
"This optimization applies to: " +
- "1. pyspark.sql.DataFrame.toPandas " +
+ "1. pyspark.sql.DataFrame.toPandas. " +
"2. pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame " +
+ "or a NumPy ndarray. " +
"The following data types are unsupported: " +
"ArrayType of TimestampType, and nested StructType.")
.version("3.0.0")
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org