Posted to commits@spark.apache.org by ma...@apache.org on 2023/06/26 16:42:32 UTC

[spark] branch master updated: [SPARK-44189][CONNECT][PYTHON] Support positional parameters by `sql()`

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e98987220ae [SPARK-44189][CONNECT][PYTHON] Support positional parameters by `sql()`
e98987220ae is described below

commit e98987220ae191ecc10944026fee9c57ddf478c1
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Mon Jun 26 19:42:17 2023 +0300

    [SPARK-44189][CONNECT][PYTHON] Support positional parameters by `sql()`
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to extend the `sql()` method of the Python Connect client to support positional parameters as a list of Python objects that can be converted to literal expressions.
    
    ```python
    def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> DataFrame:
    ```
    
    where
    
    - **args** is a dictionary of parameter names to Python objects, or a list of Python objects, that can be converted to SQL literal expressions. See the [link](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) for the value types supported in PySpark. For example: _1, "Steven", datetime.date(2023, 4, 2)_. As in the Scala/Java API, a value can also be a `Column` of a literal expression, in which case it is taken as is.
    
    For example:
    ```python
         >>> connect.sql("SELECT * FROM {df} WHERE {df[B]} > ? and ? < {df[A]}", [5, 2], df=mydf).show()
         +---+---+
         |  A|  B|
         +---+---+
         |  3|  6|
         +---+---+
    ```
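    
    Named parameters work the same way; a minimal sketch (reusing the `connect` session from the example above, with illustrative parameter names `minId`/`maxId`):
    ```python
         >>> connect.sql(
         ...     "SELECT * FROM range(10) WHERE id > :minId AND id < :maxId",
         ...     args={"minId": 5, "maxId": 8},
         ... ).show()
         +---+
         | id|
         +---+
         |  6|
         |  7|
         +---+
    ```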
    
    ### Why are the changes needed?
    To achieve feature parity with the PySpark API.
    
    ### Does this PR introduce _any_ user-facing change?
    No, the PR just extends the existing API.
    
    ### How was this patch tested?
    By running the new test:
    ```
    $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_pos_args'
    ```
    and the renamed test:
    ```
    $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_named_args'
    ```
    
    Closes #41739 from MaxGekk/positional-params-python-connect.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 python/pyspark/sql/connect/plan.py                 | 36 ++++++++++++++++------
 python/pyspark/sql/connect/session.py              |  2 +-
 .../sql/tests/connect/test_connect_basic.py        |  7 ++++-
 3 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 406f65080d1..fabab98d9b2 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1019,12 +1019,15 @@ class SubqueryAlias(LogicalPlan):
 
 
 class SQL(LogicalPlan):
-    def __init__(self, query: str, args: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, query: str, args: Optional[Union[Dict[str, Any], List]] = None) -> None:
         super().__init__(None)
 
         if args is not None:
-            for k, v in args.items():
-                assert isinstance(k, str)
+            if isinstance(args, Dict):
+                for k, v in args.items():
+                    assert isinstance(k, str)
+            else:
+                assert isinstance(args, List)
 
         self._query = query
         self._args = args
@@ -1034,8 +1037,16 @@ class SQL(LogicalPlan):
         plan.sql.query = self._query
 
         if self._args is not None and len(self._args) > 0:
-            for k, v in self._args.items():
-                plan.sql.args[k].CopyFrom(LiteralExpression._from_value(v).to_plan(session).literal)
+            if isinstance(self._args, Dict):
+                for k, v in self._args.items():
+                    plan.sql.args[k].CopyFrom(
+                        LiteralExpression._from_value(v).to_plan(session).literal
+                    )
+            else:
+                for v in self._args:
+                    plan.sql.pos_args.append(
+                        LiteralExpression._from_value(v).to_plan(session).literal
+                    )
 
         return plan
 
@@ -1043,10 +1054,17 @@ class SQL(LogicalPlan):
         cmd = proto.Command()
         cmd.sql_command.sql = self._query
         if self._args is not None and len(self._args) > 0:
-            for k, v in self._args.items():
-                cmd.sql_command.args[k].CopyFrom(
-                    LiteralExpression._from_value(v).to_plan(session).literal
-                )
+            if isinstance(self._args, Dict):
+                for k, v in self._args.items():
+                    cmd.sql_command.args[k].CopyFrom(
+                        LiteralExpression._from_value(v).to_plan(session).literal
+                    )
+            else:
+                for v in self._args:
+                    cmd.sql_command.pos_args.append(
+                        LiteralExpression._from_value(v).to_plan(session).literal
+                    )
+
         return cmd
 
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 365829ff7bc..356dacd8e18 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -489,7 +489,7 @@ class SparkSession:
 
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
 
-    def sql(self, sqlQuery: str, args: Optional[Dict[str, Any]] = None) -> "DataFrame":
+    def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> "DataFrame":
         cmd = SQL(sqlQuery, args)
         data, properties = self.client.execute_command(cmd.command(self._client))
         if "sql_command_result" in properties:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 89384b24e45..268011ef1e4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1223,11 +1223,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         pdf = self.connect.sql("SELECT 1").toPandas()
         self.assertEqual(1, len(pdf.index))
 
-    def test_sql_with_args(self):
+    def test_sql_with_named_args(self):
         df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
         df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
         self.assert_eq(df.toPandas(), df2.toPandas())
 
+    def test_sql_with_pos_args(self):
+        df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
+        df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
+        self.assert_eq(df.toPandas(), df2.toPandas())
+
     def test_head(self):
         # SPARK-41002: test `head` API in Python Client
         df = self.connect.read.table(self.tbl_name)

