Posted to commits@spark.apache.org by ma...@apache.org on 2023/06/26 16:42:32 UTC
[spark] branch master updated: [SPARK-44189][CONNECT][PYTHON] Support positional parameters by `sql()`
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e98987220ae [SPARK-44189][CONNECT][PYTHON] Support positional parameters by `sql()`
e98987220ae is described below
commit e98987220ae191ecc10944026fee9c57ddf478c1
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Mon Jun 26 19:42:17 2023 +0300
[SPARK-44189][CONNECT][PYTHON] Support positional parameters by `sql()`
### What changes were proposed in this pull request?
In the PR, I propose to extend the `sql()` method of the Python Connect client to support positional parameters, passed as a list of Python objects that can be converted to literal expressions.
```python
def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> DataFrame:
```
where
- **args** is a dictionary of parameter names to Python objects, or a list of Python objects, that can be converted to SQL literal expressions. See the [link](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) regarding the supported value types in PySpark. For example: _1, "Steven", datetime.date(2023, 4, 2)_. As in the Scala/Java API, a value can also be a `Column` of a literal expression, in which case it is taken as is.
For example:
```python
>>> connect.sql("SELECT * FROM {df} WHERE {df[B]} > ? and ? < {df[A]}", [5, 2], df=mydf).show()
+---+---+
| A| B|
+---+---+
| 3| 6|
+---+---+
```
### Why are the changes needed?
To achieve feature parity with the PySpark API.
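A regular (non-Connect) `SparkSession` already accepts a list for `args`, as the new test below relies on; a minimal sketch of the now-matching behavior:
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Positional parameters bind to `?` placeholders left to right.
spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7]).show()
```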
### Does this PR introduce _any_ user-facing change?
No, the PR just extends the existing API.
### How was this patch tested?
By running the new test:
```
$ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_pos_args'
```
and the renamed test:
```
$ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_named_args'
```
Closes #41739 from MaxGekk/positional-params-python-connect.
Authored-by: Max Gekk <ma...@gmail.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
---
python/pyspark/sql/connect/plan.py | 36 ++++++++++++++++------
python/pyspark/sql/connect/session.py | 2 +-
.../sql/tests/connect/test_connect_basic.py | 7 ++++-
3 files changed, 34 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 406f65080d1..fabab98d9b2 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1019,12 +1019,15 @@ class SubqueryAlias(LogicalPlan):
class SQL(LogicalPlan):
- def __init__(self, query: str, args: Optional[Dict[str, Any]] = None) -> None:
+ def __init__(self, query: str, args: Optional[Union[Dict[str, Any], List]] = None) -> None:
super().__init__(None)
if args is not None:
- for k, v in args.items():
- assert isinstance(k, str)
+ if isinstance(args, Dict):
+ for k, v in args.items():
+ assert isinstance(k, str)
+ else:
+ assert isinstance(args, List)
self._query = query
self._args = args
@@ -1034,8 +1037,16 @@ class SQL(LogicalPlan):
plan.sql.query = self._query
if self._args is not None and len(self._args) > 0:
- for k, v in self._args.items():
- plan.sql.args[k].CopyFrom(LiteralExpression._from_value(v).to_plan(session).literal)
+ if isinstance(self._args, Dict):
+ for k, v in self._args.items():
+ plan.sql.args[k].CopyFrom(
+ LiteralExpression._from_value(v).to_plan(session).literal
+ )
+ else:
+ for v in self._args:
+ plan.sql.pos_args.append(
+ LiteralExpression._from_value(v).to_plan(session).literal
+ )
return plan
@@ -1043,10 +1054,17 @@ class SQL(LogicalPlan):
cmd = proto.Command()
cmd.sql_command.sql = self._query
if self._args is not None and len(self._args) > 0:
- for k, v in self._args.items():
- cmd.sql_command.args[k].CopyFrom(
- LiteralExpression._from_value(v).to_plan(session).literal
- )
+ if isinstance(self._args, Dict):
+ for k, v in self._args.items():
+ cmd.sql_command.args[k].CopyFrom(
+ LiteralExpression._from_value(v).to_plan(session).literal
+ )
+ else:
+ for v in self._args:
+ cmd.sql_command.pos_args.append(
+ LiteralExpression._from_value(v).to_plan(session).literal
+ )
+
return cmd
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 365829ff7bc..356dacd8e18 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -489,7 +489,7 @@ class SparkSession:
createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
- def sql(self, sqlQuery: str, args: Optional[Dict[str, Any]] = None) -> "DataFrame":
+ def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> "DataFrame":
cmd = SQL(sqlQuery, args)
data, properties = self.client.execute_command(cmd.command(self._client))
if "sql_command_result" in properties:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 89384b24e45..268011ef1e4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1223,11 +1223,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
pdf = self.connect.sql("SELECT 1").toPandas()
self.assertEqual(1, len(pdf.index))
- def test_sql_with_args(self):
+ def test_sql_with_named_args(self):
df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
self.assert_eq(df.toPandas(), df2.toPandas())
+ def test_sql_with_pos_args(self):
+ df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
+ df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
+ self.assert_eq(df.toPandas(), df2.toPandas())
+
def test_head(self):
# SPARK-41002: test `head` API in Python Client
df = self.connect.read.table(self.tbl_name)