Posted to commits@spark.apache.org by gu...@apache.org on 2024/01/02 07:20:22 UTC
(spark) branch master updated: [SPARK-46540][PYTHON] Respect column names when Python data source read function outputs named Row objects
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 48a09c457cf5 [SPARK-46540][PYTHON] Respect column names when Python data source read function outputs named Row objects
48a09c457cf5 is described below
commit 48a09c457cf5854d956138d3881d2c45e15b291d
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Tue Jan 2 16:20:11 2024 +0900
[SPARK-46540][PYTHON] Respect column names when Python data source read function outputs named Row objects
### What changes were proposed in this pull request?
This PR fixes an issue where the `read` method of a Python `DataSourceReader` yields named `Row` objects. Currently, the field names in those `Row` objects are ignored and the values are consumed positionally:
```Python
def read(self, ...):
    yield Row(a=1, b=2)
    yield Row(b=3, a=2)
```
The result should be `[Row(a=1, b=2), Row(a=2, b=3)]`; positional assignment instead reads `Row(b=3, a=2)` as `(3, 2)` and yields `[Row(a=1, b=2), Row(a=3, b=2)]`.
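For illustration, here is a minimal standalone sketch (not the worker code itself) of the name-based assignment this PR introduces; `column_names` stands in for the data source's declared schema:
```Python
from pyspark.sql import Row

column_names = ["a", "b"]
results = [Row(a=1, b=2), Row(b=3, a=2)]

columns = {name: [] for name in column_names}
for result in results:
    if isinstance(result, Row) and hasattr(result, "__fields__"):
        # Named Row: look each value up by its field name.
        for name in column_names:
            columns[name].append(result[name])
    else:
        # Plain tuple (or unnamed Row): fall back to positional order.
        for name, value in zip(column_names, result):
            columns[name].append(value)

print(columns)  # {'a': [1, 2], 'b': [2, 3]}
```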
### Why are the changes needed?
To fix an incorrect behavior.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests (Python and Scala).
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #44531 from allisonwang-db/spark-46540-named-rows.
Authored-by: allisonwang-db <al...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/tests/test_python_datasource.py | 16 +++++++++++++
python/pyspark/sql/worker/plan_data_source_read.py | 24 ++++++++++++++++++--
.../execution/python/PythonDataSourceSuite.scala | 26 ++++++++++++++++++++++
3 files changed, 64 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 32333a8ccee9..8517d8f36382 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -135,6 +135,22 @@ class BasePythonDataSourceTestsMixin:
         df = self.spark.read.format("test").load()
         assertDataFrameEqual(df, [Row(0, 1)])
 
+    def test_data_source_read_output_named_row(self):
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([Row(j=1, i=0), Row(i=1, j=2)])
+        )
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1), Row(1, 2)])
+
+    def test_data_source_read_output_named_row_with_wrong_schema(self):
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
+        )
+        with self.assertRaisesRegex(
+            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+        ):
+            self.spark.read.format("test").load().show()
+
     def test_data_source_read_output_none(self):
         self.register_data_source(read_func=lambda schema, partition: None)
         df = self.spark.read.format("test").load()
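For context, a hedged end-to-end sketch of the behavior these tests cover, written against the public Python data source API; the class names and the `"named_rows"` format name are illustrative, and registration via `spark.dataSource.register` assumes the Spark 4.0 API:
```Python
from pyspark.sql import Row, SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader

class NamedRowReader(DataSourceReader):
    def read(self, partition):
        # Field order differs from the declared schema; with this fix the
        # values still land in the correct columns.
        yield Row(j=1, i=0)
        yield Row(i=1, j=2)

class NamedRowSource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "named_rows"

    def schema(self) -> str:
        return "i int, j int"

    def reader(self, schema):
        return NamedRowReader()

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(NamedRowSource)
spark.read.format("named_rows").load().show()
# Expected rows under columns (i, j): (0, 1) and (1, 2).
```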
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index d2fcb5096ae2..d4693f5ff7be 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -29,6 +29,7 @@ from pyspark.serializers import (
     write_int,
     SpecialLengths,
 )
+from pyspark.sql import Row
 from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
 from pyspark.sql.datasource import DataSource, InputPartition
 from pyspark.sql.pandas.types import to_arrow_schema
@@ -234,6 +235,8 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Convert the results from the `reader.read` method to an iterator of arrow batches.
         num_cols = len(column_names)
+        col_mapping = {name: i for i, name in enumerate(column_names)}
+        col_name_set = set(column_names)
         for batch in batched(output_iter, max_arrow_batch_size):
             pylist: List[List] = [[] for _ in range(num_cols)]
             for result in batch:
@@ -258,8 +261,25 @@ def main(infile: IO, outfile: IO) -> None:
                         },
                     )
 
-                for col in range(num_cols):
-                    pylist[col].append(column_converters[col](result[col]))
+                # Assign output values by name of the field, not position, if the result is a
+                # named `Row` object.
+                if isinstance(result, Row) and hasattr(result, "__fields__"):
+                    # Check if the names are the same as the schema.
+                    if set(result.__fields__) != col_name_set:
+                        raise PySparkRuntimeError(
+                            error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                            message_parameters={
+                                "expected": str(column_names),
+                                "actual": str(result.__fields__),
+                            },
+                        )
+                    # Assign the values by name.
+                    for name in column_names:
+                        idx = col_mapping[name]
+                        pylist[idx].append(column_converters[idx](result[name]))
+                else:
+                    for col in range(num_cols):
+                        pylist[col].append(column_converters[col](result[col]))
 
             yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
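Note that the schema check above compares field names as a set, so any field order is accepted while missing or extra fields are rejected. A standalone sketch of that check (using a plain `ValueError` in place of the `PySparkRuntimeError` error-class machinery):
```Python
from pyspark.sql import Row

column_names = ["i", "j"]
col_name_set = set(column_names)

for result in [Row(j=2, i=1), Row(j=3, k=4)]:
    if set(result.__fields__) != col_name_set:
        # Row(j=2, i=1) passes (reordered fields); Row(j=3, k=4) raises.
        raise ValueError(
            f"expected fields {column_names}, got {list(result.__fields__)}"
        )
```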
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 080f57aa08a0..49fb2e859fff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -455,6 +455,32 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-46540: data source read output named rows") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+        |from pyspark.sql.datasource import DataSource, DataSourceReader
+        |class SimpleDataSourceReader(DataSourceReader):
+        |    def read(self, partition):
+        |        from pyspark.sql import Row
+        |        yield Row(x = 0, y = 1)
+        |        yield Row(y = 2, x = 1)
+        |        yield Row(2, 3)
+        |        yield (3, 4)
+        |
+        |class $dataSourceName(DataSource):
+        |    def schema(self) -> str:
+        |        return "x int, y int"
+        |
+        |    def reader(self, schema):
+        |        return SimpleDataSourceReader()
+        |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).load()
+    checkAnswer(df, Seq(Row(0, 1), Row(1, 2), Row(2, 3), Row(3, 4)))
+  }
+
   test("SPARK-46424: Support Python metrics") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =