Posted to commits@spark.apache.org by gu...@apache.org on 2022/06/13 00:06:48 UTC

[spark] branch master updated: [SPARK-39406][PYTHON] Accept NumPy array in createDataFrame

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d393dde751 [SPARK-39406][PYTHON] Accept NumPy array in createDataFrame
5d393dde751 is described below

commit 5d393dde751181287a7fa8cc8f0f56bdb9d1d1ac
Author: Xinrong Meng <xi...@databricks.com>
AuthorDate: Mon Jun 13 09:06:36 2022 +0900

    [SPARK-39406][PYTHON] Accept NumPy array in createDataFrame
    
    ### What changes were proposed in this pull request?
    Accept NumPy arrays in createDataFrame, with support for the existing dtypes.
    
    Note that
    - due to the constraints of a Spark DataFrame, only 1-dimensional and 2-dimensional arrays are supported (see the sketch below);
    - full NumPy <> PySpark type mappings will be introduced as a follow-up.
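
    For example, based on the dtype mappings exercised by the unit tests in this patch
    (a rough sketch; exact dtypes may vary by platform):

    ```py
    >>> import numpy as np
    >>> spark.createDataFrame(np.array([1, 2])).dtypes      # int64 -> bigint
    [('value', 'bigint')]
    >>> spark.createDataFrame(np.array([0.1, 0.2])).dtypes  # float64 -> double
    [('value', 'double')]
    >>> spark.createDataFrame(np.array(0))                  # 0-dimensional input is rejected
    Traceback (most recent call last):
    ...
    ValueError: NumPy array input should be of 1 or 2 dimensions.
    ```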
    
    ### Why are the changes needed?
    This is part of SPARK-39405, which tracks NumPy input support in Spark SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, NumPy arrays are now accepted in createDataFrame:
    
    Before:
    ```py
    >>> spark.createDataFrame(np.array([[1, 2], [3, 4]]))
    Traceback (most recent call last):
    ...
    TypeError: Can not infer schema for type: <class 'numpy.ndarray'>
    
    >>> spark.createDataFrame(np.array([0.1, 0.2]))
    Traceback (most recent call last):
    ...
    TypeError: Can not infer schema for type: <class 'numpy.float64'>
    ```
    
    After:
    ```py
    >>> spark.createDataFrame(np.array([[1, 2], [3, 4]])).show()
    +---+---+
    | _1| _2|
    +---+---+
    |  1|  2|
    |  3|  4|
    +---+---+
    
    >>> spark.createDataFrame(np.array([0.1, 0.2])).show()
    +-----+
    |value|
    +-----+
    |  0.1|
    |  0.2|
    +-----+
    ```
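
    Since ndarray input is routed through the pandas conversion path, the Arrow
    optimization toggle also applies to it. A hedged example, assuming an active
    `spark` session (the non-Arrow fallback path produces the same result, as the
    toggle test in this patch asserts):

    ```py
    >>> spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")
    >>> spark.createDataFrame(np.array([0.1, 0.2])).show()
    +-----+
    |value|
    +-----+
    |  0.1|
    |  0.2|
    +-----+
    ```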
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #36793 from xinrong-databricks/createDataFrame2.
    
    Authored-by: Xinrong Meng <xi...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/docs/source/getting_started/install.rst     |  2 +-
 python/pyspark/sql/pandas/_typing/__init__.pyi     |  2 ++
 python/pyspark/sql/session.py                      | 34 ++++++++++++++++-----
 python/pyspark/sql/tests/test_arrow.py             | 35 ++++++++++++++++++++--
 python/setup.py                                    |  1 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  3 +-
 6 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst
index e5c1455da7a..afcdb7291c5 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -159,7 +159,7 @@ Package       Minimum supported version Note
 `py4j`        0.10.9.5                  Required
 `pandas`      1.0.5                     Required for pandas API on Spark
 `pyarrow`     1.0.0                     Required for pandas API on Spark
-`numpy`       1.15                      Required for pandas API on Spark and MLLib DataFrame-based API
+`numpy`       1.15                      Required for pandas API on Spark and MLLib DataFrame-based API; Optional for Spark SQL
 ============= ========================= ======================================
 
 Note that PySpark requires Java 8 or later with ``JAVA_HOME`` properly set.  
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi
index 6ecd04f057e..27ac64a7238 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -32,9 +32,11 @@ from types import FunctionType
 from pyspark.sql._typing import LiteralType
 from pandas.core.frame import DataFrame as PandasDataFrame
 from pandas.core.series import Series as PandasSeries
+from numpy import ndarray as NDArray
 
 import pyarrow
 
+ArrayLike = NDArray
 DataFrameLike = PandasDataFrame
 SeriesLike = PandasSeries
 DataFrameOrSeriesLike = Union[DataFrameLike, SeriesLike]
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 2ea8fa792e2..a4c36d719b2 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -63,7 +63,7 @@ from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_prefer
 if TYPE_CHECKING:
     from pyspark.sql._typing import AtomicValue, RowLike
     from pyspark.sql.catalog import Catalog
-    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+    from pyspark.sql.pandas._typing import ArrayLike, DataFrameLike as PandasDataFrameLike
     from pyspark.sql.streaming import StreamingQueryManager
     from pyspark.sql.udf import UDFRegistration
 
@@ -837,13 +837,14 @@ class SparkSession(SparkConversionMixin):
 
     def createDataFrame(  # type: ignore[misc]
         self,
-        data: Union[RDD[Any], Iterable[Any], "PandasDataFrameLike"],
+        data: Union[RDD[Any], Iterable[Any], "PandasDataFrameLike", "ArrayLike"],
         schema: Optional[Union[AtomicType, StructType, str]] = None,
         samplingRatio: Optional[float] = None,
         verifySchema: bool = True,
     ) -> DataFrame:
         """
-        Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
+        Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`
+        or a :class:`numpy.ndarray`.
 
         When ``schema`` is a list of column names, the type of each column
         will be inferred from ``data``.
@@ -870,8 +871,8 @@ class SparkSession(SparkConversionMixin):
         ----------
         data : :class:`RDD` or iterable
             an RDD of any kind of SQL data representation (:class:`Row`,
-            :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or
-            :class:`pandas.DataFrame`.
+            :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`,
+            :class:`pandas.DataFrame` or :class:`numpy.ndarray`.
         schema : :class:`pyspark.sql.types.DataType`, str or list, optional
             a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
             column names, default is None.  The data type string format equals to
@@ -952,12 +953,31 @@ class SparkSession(SparkConversionMixin):
             schema = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
 
         try:
-            import pandas
+            import pandas as pd
 
             has_pandas = True
         except Exception:
             has_pandas = False
-        if has_pandas and isinstance(data, pandas.DataFrame):
+
+        try:
+            import numpy as np
+
+            has_numpy = True
+        except Exception:
+            has_numpy = False
+
+        if has_numpy and isinstance(data, np.ndarray):
+            # `data` of numpy.ndarray type will be converted to a pandas DataFrame,
+            # so pandas is required.
+            from pyspark.sql.pandas.utils import require_minimum_pandas_version
+
+            require_minimum_pandas_version()
+            if data.ndim not in [1, 2]:
+                raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
+            column_names = ["value"] if data.ndim == 1 else ["_1", "_2"]
+            data = pd.DataFrame(data, columns=column_names)
+
+        if has_pandas and isinstance(data, pd.DataFrame):
             # Create a DataFrame from pandas DataFrame.
             return super(SparkSession, self).createDataFrame(  # type: ignore[call-overload]
                 data, schema, samplingRatio, verifySchema
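
For reference, a minimal standalone sketch of the ndarray-to-pandas conversion added
above (the helper name is hypothetical; in the patch this logic lives inline in
`SparkSession.createDataFrame`):

```py
import numpy as np
import pandas as pd

def ndarray_to_pdf(arr: np.ndarray) -> pd.DataFrame:
    # Only 1- and 2-dimensional input can be mapped onto the rows and columns
    # of a Spark DataFrame.
    if arr.ndim not in [1, 2]:
        raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
    # 1-D input becomes a single "value" column; 2-D input is labelled "_1", "_2"
    # (the patch hard-codes two names, matching the examples in the commit message).
    column_names = ["value"] if arr.ndim == 1 else ["_1", "_2"]
    return pd.DataFrame(arr, columns=column_names)

print(ndarray_to_pdf(np.array([[1, 2], [3, 4]])))
```
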
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index ff42ade1407..b737848b11a 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -24,6 +24,8 @@ import warnings
 from distutils.version import LooseVersion
 from typing import cast
 
+import numpy as np
+
 from pyspark import SparkContext, SparkConf
 from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import rand, udf, assert_true, lit
@@ -179,6 +181,15 @@ class ArrowTests(ReusedSQLTestCase):
         data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
         return pd.DataFrame(data=data_dict)
 
+    @property
+    def create_np_arrs(self):
+        return [
+            np.array([1, 2]),  # dtype('int64')
+            np.array([0.1, 0.2]),  # dtype('float64')
+            np.array([[1, 2], [3, 4]]),  # dtype('int64')
+            np.array([[0.1, 0.2], [0.3, 0.4]]),  # dtype('float64')
+        ]
+
     def test_toPandas_fallback_enabled(self):
         ts = datetime.datetime(2015, 11, 1, 0, 30)
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
@@ -391,11 +402,11 @@ class ArrowTests(ReusedSQLTestCase):
             with self.assertRaisesRegex(Exception, "My error"):
                 df.toPandas()
 
-    def _createDataFrame_toggle(self, pdf, schema=None):
+    def _createDataFrame_toggle(self, data, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
+            df_no_arrow = self.spark.createDataFrame(data, schema=schema)
 
-        df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+        df_arrow = self.spark.createDataFrame(data, schema=schema)
 
         return df_no_arrow, df_arrow
 
@@ -495,6 +506,24 @@ class ArrowTests(ReusedSQLTestCase):
         schema_rt = from_arrow_schema(arrow_schema)
         self.assertEqual(self.schema, schema_rt)
 
+    def test_createDataFrame_with_ndarray(self):
+        arrs = self.create_np_arrs
+        collected_dtypes = [
+            ([Row(value=1), Row(value=2)], [("value", "bigint")]),
+            ([Row(value=0.1), Row(value=0.2)], [("value", "double")]),
+            ([Row(_1=1, _2=2), Row(_1=3, _2=4)], [("_1", "bigint"), ("_2", "bigint")]),
+            ([Row(_1=0.1, _2=0.2), Row(_1=0.3, _2=0.4)], [("_1", "double"), ("_2", "double")]),
+        ]
+        for arr, [collected, dtypes] in zip(arrs, collected_dtypes):
+            df, df_arrow = self._createDataFrame_toggle(arr)
+            self.assertEqual(df.dtypes, dtypes)
+            self.assertEqual(df_arrow.dtypes, dtypes)
+            self.assertEqual(df.collect(), collected)
+            self.assertEqual(df_arrow.collect(), collected)
+
+        with self.assertRaisesRegex(ValueError, "NumPy array input should be of 1 or 2 dimensions"):
+            self.spark.createDataFrame(np.array(0))
+
     def test_createDataFrame_with_array_type(self):
         pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [["x", "y"], ["y", "z"]]})
         df, df_arrow = self._createDataFrame_toggle(pdf)
diff --git a/python/setup.py b/python/setup.py
index 6128b206223..061dc9d663d 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -266,6 +266,7 @@ try:
             'sql': [
                 'pandas>=%s' % _minimum_pandas_version,
                 'pyarrow>=%s' % _minimum_pyarrow_version,
+                'numpy>=1.15',
             ],
             'pandas_on_spark': [
                 'pandas>=%s' % _minimum_pandas_version,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5e1f3956159..b8a752e90ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2557,8 +2557,9 @@ object SQLConf {
     buildConf("spark.sql.execution.arrow.pyspark.enabled")
       .doc("When true, make use of Apache Arrow for columnar data transfers in PySpark. " +
         "This optimization applies to: " +
-        "1. pyspark.sql.DataFrame.toPandas " +
+        "1. pyspark.sql.DataFrame.toPandas. " +
         "2. pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame " +
+        "or a NumPy ndarray. " +
         "The following data types are unsupported: " +
         "ArrayType of TimestampType, and nested StructType.")
       .version("3.0.0")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org