Posted to commits@spark.apache.org by gu...@apache.org on 2024/01/02 07:20:22 UTC

(spark) branch master updated: [SPARK-46540][PYTHON] Respect column names when Python data source read function outputs named Row objects

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 48a09c457cf5 [SPARK-46540][PYTHON] Respect column names when Python data source read function outputs named Row objects
48a09c457cf5 is described below

commit 48a09c457cf5854d956138d3881d2c45e15b291d
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Tue Jan 2 16:20:11 2024 +0900

    [SPARK-46540][PYTHON] Respect column names when Python data source read function outputs named Row objects
    
    ### What changes were proposed in this pull request?
    
    This PR fixes an issue where the `read` method of a Python `DataSourceReader` yields named `Row` objects.
    Currently, the field names in those `Row` objects are ignored and the values are assigned to output columns by position:
    ```Python
    def read(self,...):
        yield Row(a=1, b=2)
        yield Row(b=3, a=2)
    ```
    The result should be `[Row(a=1, b=2), Row(a=2, b=3)]`, not `[Row(a=1, b=2), Row(a=3, b=2)]`.
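    
    For context, here is a minimal end-to-end sketch of a data source that exercises this path
    (the `NamedRowDataSource` / `named_rows` names and the `spark.dataSource.register` call are
    illustrative, and an active `spark` session is assumed):
    ```Python
    from pyspark.sql import Row
    from pyspark.sql.datasource import DataSource, DataSourceReader

    class NamedRowReader(DataSourceReader):
        def read(self, partition):
            # Field names, not positions, decide which output column each value fills.
            yield Row(a=1, b=2)
            yield Row(b=3, a=2)

    class NamedRowDataSource(DataSource):
        @classmethod
        def name(cls):
            return "named_rows"

        def schema(self):
            return "a int, b int"

        def reader(self, schema):
            return NamedRowReader()

    spark.dataSource.register(NamedRowDataSource)
    spark.read.format("named_rows").load().show()
    # With this fix, the rows come back as (1, 2) and (2, 3), matching the field names.
    ```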
    
    ### Why are the changes needed?
    
    To fix incorrect behavior: values yielded as named `Row` objects should be mapped to output columns by field name rather than by position.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests in `python/pyspark/sql/tests/test_python_datasource.py` and `PythonDataSourceSuite.scala`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44531 from allisonwang-db/spark-46540-named-rows.
    
    Authored-by: allisonwang-db <al...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/tests/test_python_datasource.py | 16 +++++++++++++
 python/pyspark/sql/worker/plan_data_source_read.py | 24 ++++++++++++++++++--
 .../execution/python/PythonDataSourceSuite.scala   | 26 ++++++++++++++++++++++
 3 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 32333a8ccee9..8517d8f36382 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -135,6 +135,22 @@ class BasePythonDataSourceTestsMixin:
         df = self.spark.read.format("test").load()
         assertDataFrameEqual(df, [Row(0, 1)])
 
+    def test_data_source_read_output_named_row(self):
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([Row(j=1, i=0), Row(i=1, j=2)])
+        )
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1), Row(1, 2)])
+
+    def test_data_source_read_output_named_row_with_wrong_schema(self):
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
+        )
+        with self.assertRaisesRegex(
+            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+        ):
+            self.spark.read.format("test").load().show()
+
     def test_data_source_read_output_none(self):
         self.register_data_source(read_func=lambda schema, partition: None)
         df = self.spark.read.format("test").load()
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index d2fcb5096ae2..d4693f5ff7be 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -29,6 +29,7 @@ from pyspark.serializers import (
     write_int,
     SpecialLengths,
 )
+from pyspark.sql import Row
 from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
 from pyspark.sql.datasource import DataSource, InputPartition
 from pyspark.sql.pandas.types import to_arrow_schema
@@ -234,6 +235,8 @@ def main(infile: IO, outfile: IO) -> None:
 
             # Convert the results from the `reader.read` method to an iterator of arrow batches.
             num_cols = len(column_names)
+            col_mapping = {name: i for i, name in enumerate(column_names)}
+            col_name_set = set(column_names)
             for batch in batched(output_iter, max_arrow_batch_size):
                 pylist: List[List] = [[] for _ in range(num_cols)]
                 for result in batch:
@@ -258,8 +261,25 @@ def main(infile: IO, outfile: IO) -> None:
                             },
                         )
 
-                    for col in range(num_cols):
-                        pylist[col].append(column_converters[col](result[col]))
+                    # Assign output values by name of the field, not position, if the result is a
+                    # named `Row` object.
+                    if isinstance(result, Row) and hasattr(result, "__fields__"):
+                        # Check if the names are the same as the schema.
+                        if set(result.__fields__) != col_name_set:
+                            raise PySparkRuntimeError(
+                                error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                                message_parameters={
+                                    "expected": str(column_names),
+                                    "actual": str(result.__fields__),
+                                },
+                            )
+                        # Assign the values by name.
+                        for name in column_names:
+                            idx = col_mapping[name]
+                            pylist[idx].append(column_converters[idx](result[name]))
+                    else:
+                        for col in range(num_cols):
+                            pylist[col].append(column_converters[col](result[col]))
 
                 yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
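
In essence, the worker change above first checks whether a yielded result is a named `Row`; if so, it validates that the field set matches the schema and assigns each value to its column by name, otherwise it keeps the original positional assignment. A simplified standalone sketch of that mapping (the `append_result` helper and its error message are illustrative, not the actual worker code):

```Python
from pyspark.sql import Row

def append_result(result, column_names, pylist):
    """Append one read() result to per-column buffers, by field name when possible."""
    if isinstance(result, Row) and hasattr(result, "__fields__"):
        # Named Row: the field set must match the schema exactly.
        if set(result.__fields__) != set(column_names):
            raise RuntimeError(
                f"fields {result.__fields__} do not match schema columns {column_names}"
            )
        for i, name in enumerate(column_names):
            pylist[i].append(result[name])  # look the value up by name
    else:
        # Plain tuple/list (or positional Row): fall back to positional assignment.
        for i in range(len(column_names)):
            pylist[i].append(result[i])

# Schema is (x, y); the second row lists its fields in the opposite order.
columns = ["x", "y"]
buffers = [[] for _ in columns]
for row in [Row(x=0, y=1), Row(y=2, x=1), (2, 3)]:
    append_result(row, columns, buffers)
print(buffers)  # [[0, 1, 2], [1, 2, 3]]
```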
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 080f57aa08a0..49fb2e859fff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -455,6 +455,32 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-46540: data source read output named rows") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def read(self, partition):
+         |        from pyspark.sql import Row
+         |        yield Row(x = 0, y = 1)
+         |        yield Row(y = 2, x = 1)
+         |        yield Row(2, 3)
+         |        yield (3, 4)
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "x int, y int"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).load()
+    checkAnswer(df, Seq(Row(0, 1), Row(1, 2), Row(2, 3), Row(3, 4)))
+  }
+
   test("SPARK-46424: Support Python metrics") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org