You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/08/31 03:37:52 UTC

[spark] branch master updated: [SPARK-40271][PYTHON] Support list type for `pyspark.sql.functions.lit`

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 65d89f8e897 [SPARK-40271][PYTHON] Support list type for `pyspark.sql.functions.lit`
65d89f8e897 is described below

commit 65d89f8e897449f7f026144a76328ff545fecde2
Author: itholic <ha...@databricks.com>
AuthorDate: Wed Aug 31 11:37:20 2022 +0800

    [SPARK-40271][PYTHON] Support list type for `pyspark.sql.functions.lit`
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to support `list` type for `pyspark.sql.functions.lit`.
    
    ### Why are the changes needed?
    
    To improve the API usability.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, now the `list` type is supported by `pyspark.sql.functions.lit`, as shown below:
    
    - Before
    ```python
    >>> spark.range(3).withColumn("c", lit([1,2,3])).show()
    Traceback (most recent call last):
    ...
    : org.apache.spark.SparkRuntimeException: [UNSUPPORTED_FEATURE.LITERAL_TYPE] The feature is not supported: Literal for '[1, 2, 3]' of class java.util.ArrayList.
            at org.apache.spark.sql.errors.QueryExecutionErrors$.literalTypeUnsupportedError(QueryExecutionErrors.scala:302)
            at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:100)
            at org.apache.spark.sql.functions$.lit(functions.scala:125)
            at org.apache.spark.sql.functions.lit(functions.scala)
            at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:104)
            at java.base/java.lang.reflect.Method.invoke(Method.java:577)
            at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
            at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
            at py4j.Gateway.invoke(Gateway.java:282)
            at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
            at py4j.commands.CallCommand.execute(CallCommand.java:79)
            at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
            at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
            at java.base/java.lang.Thread.run(Thread.java:833)
    ```
    
    - After
    ```python
    >>> spark.range(3).withColumn("c", lit([1,2,3])).show()
    +---+---------+
    | id|        c|
    +---+---------+
    |  0|[1, 2, 3]|
    |  1|[1, 2, 3]|
    |  2|[1, 2, 3]|
    +---+---------+
    ```
    
    ### How was this patch tested?
    
    Added doctest & unit test.
    
    Closes #37722 from itholic/SPARK-40271.
    
    Authored-by: itholic <ha...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/functions.py            | 23 +++++++++++++++++++++--
 python/pyspark/sql/tests/test_functions.py | 26 ++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 03c16db602f..e7a7a1b37f3 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -131,10 +131,13 @@ def lit(col: Any) -> Column:
 
     Parameters
     ----------
-    col : :class:`~pyspark.sql.Column` or Python primitive type.
+    col : :class:`~pyspark.sql.Column`, str, int, float, bool or list.
         the value to make it as a PySpark literal. If a column is passed,
         it returns the column as is.
 
+        .. versionchanged:: 3.4.0
+            Since 3.4.0, it supports the list type.
+
     Returns
     -------
     :class:`~pyspark.sql.Column`
@@ -149,8 +152,24 @@ def lit(col: Any) -> Column:
     +------+---+
     |     5|  0|
     +------+---+
+
+    Create a literal from a list.
+
+    >>> spark.range(1).select(lit([1, 2, 3])).show()
+    +--------------+
+    |array(1, 2, 3)|
+    +--------------+
+    |     [1, 2, 3]|
+    +--------------+
     """
-    return col if isinstance(col, Column) else _invoke_function("lit", col)
+    if isinstance(col, Column):
+        return col
+    elif isinstance(col, list):
+        if any(isinstance(c, Column) for c in col):
+            raise ValueError("lit does not allow for list of Columns")
+        return array(*[lit(item) for item in col])
+    else:
+        return _invoke_function("lit", col)
 
 
 def col(col: str) -> Column:
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 102ebef8317..1d02a540558 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -962,6 +962,32 @@ class FunctionsTests(ReusedSQLTestCase):
         actual = self.spark.range(1).select(lit(td)).first()[0]
         self.assertEqual(actual, td)
 
+    def test_lit_list(self):
+        # SPARK-40271: added list type supporting
+        test_list = [1, 2, 3]
+        expected = [1, 2, 3]
+        actual = self.spark.range(1).select(lit(test_list)).first()[0]
+        self.assertEqual(actual, expected)
+
+        test_list = [[1, 2, 3], [3, 4]]
+        expected = [[1, 2, 3], [3, 4]]
+        actual = self.spark.range(1).select(lit(test_list)).first()[0]
+        self.assertEqual(actual, expected)
+
+        test_list = ["a", 1, None, 1.0]
+        expected = ["a", "1", None, "1.0"]
+        actual = self.spark.range(1).select(lit(test_list)).first()[0]
+        self.assertEqual(actual, expected)
+
+        test_list = [["a", 1, None, 1.0], [1, None, "b"]]
+        expected = [["a", "1", None, "1.0"], ["1", None, "b"]]
+        actual = self.spark.range(1).select(lit(test_list)).first()[0]
+        self.assertEqual(actual, expected)
+
+        df = self.spark.range(10)
+        with self.assertRaisesRegex(ValueError, "lit does not allow for list of Columns"):
+            lit([df.id, df.id])
+
     # Test added for SPARK-39832; change Python API to accept both col & str as input
     def test_regexp_replace(self):
         df = self.spark.createDataFrame(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org