Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/29 00:59:06 UTC

[spark] branch master updated: [SPARK-42907][CONNECT][PYTHON] Implement Avro functions

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a56c172831 [SPARK-42907][CONNECT][PYTHON] Implement Avro functions
5a56c172831 is described below

commit 5a56c17283103821714ffaaf1c764e05d0ff6b58
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Mar 29 09:58:52 2023 +0900

    [SPARK-42907][CONNECT][PYTHON] Implement Avro functions
    
    ### What changes were proposed in this pull request?
    Implement the Avro functions (`from_avro`, `to_avro`) for the Spark Connect Python client, with server-side support in `SparkConnectPlanner`.
    
    ### Why are the changes needed?
    For function parity between the Spark Connect Python client and vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The existing `from_avro` and `to_avro` APIs in `pyspark.sql.avro.functions` now also work with Spark Connect.
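
    The new module `python/pyspark/sql/connect/avro/functions.py` keeps the existing PySpark signatures. A signature-only sketch of what it exposes (string annotations so the snippet stands alone; see the full diff below):
    ```python
    from typing import Dict, Optional

    def from_avro(
        data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
    ) -> "Column": ...

    def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> "Column": ...
    ```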
    
    ### How was this patch tested?
    Added doctests and checked manually:
    ```
    (spark_dev) ➜  spark git:(connect_avro_functions) ✗ bin/pyspark --remote "local[*]" --jars connector/avro/target/scala-2.12/spark-avro_2.12-3.5.0-SNAPSHOT.jar
    Python 3.9.16 (main, Mar  8 2023, 04:29:24)
    Type 'copyright', 'credits' or 'license' for more information
    IPython 8.11.0 -- An enhanced Interactive Python. Type '?' for help.
    23/03/23 16:28:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
    Setting default log level to "WARN".
    To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
    Welcome to
          ____              __
         / __/__  ___ _____/ /__
        _\ \/ _ \/ _ `/ __/  '_/
       /__ / .__/\_,_/_/ /_/\_\   version 3.5.0.dev0
          /_/
    
    Using Python version 3.9.16 (main, Mar  8 2023 04:29:24)
    Client connected to the Spark Connect server at localhost
    SparkSession available as 'spark'.
    
    In [1]:     >>> from pyspark.sql import Row
       ...:     >>> from pyspark.sql.avro.functions import from_avro, to_avro
       ...:     >>> data = [(1, Row(age=2, name='Alice'))]
       ...:     >>> df = spark.createDataFrame(data, ("key", "value"))
       ...:     >>> avroDf = df.select(to_avro(df.value).alias("avro"))
    
    In [2]: avroDf.collect()
    Out[2]: [Row(avro=bytearray(b'\x00\x00\x04\x00\nAlice'))]
    ```
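
    The session above exercises `to_avro` only. The reverse direction is covered by the existing `from_avro` doctest (reusing `avroDf` from the session; the expected output below is the documented doctest result, not a re-run):
    ```python
    >>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields":
    ...     [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord",
    ...     "fields":[{"name":"age","type":["long","null"]},
    ...     {"name":"name","type":["string","null"]}]},"null"]}]}'''
    >>> avroDf.select(from_avro(avroDf.avro, jsonFormatSchema).alias("value")).collect()
    [Row(value=Row(avro=Row(age=2, name='Alice')))]
    ```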
    
    Closes #40535 from zhengruifeng/connect_avro_functions.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 assembly/pom.xml                                   |   6 ++
 .../apache/spark/sql/avro/AvroDataToCatalyst.scala |   2 +-
 .../apache/spark/sql/avro/CatalystDataToAvro.scala |   2 +-
 connector/connect/server/pom.xml                   |   6 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  35 +++++++
 dev/sparktestsupport/modules.py                    |   3 +-
 python/pyspark/sql/avro/functions.py               |  10 +-
 python/pyspark/sql/connect/avro/__init__.py        |  18 ++++
 python/pyspark/sql/connect/avro/functions.py       | 108 +++++++++++++++++++++
 python/pyspark/sql/utils.py                        |  16 +++
 10 files changed, 202 insertions(+), 4 deletions(-)

diff --git a/assembly/pom.xml b/assembly/pom.xml
index 36cc6078438..09d6bd8a33f 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -160,6 +160,12 @@
           <artifactId>spark-connect_${scala.binary.version}</artifactId>
           <version>${project.version}</version>
         </dependency>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-avro_${scala.binary.version}</artifactId>
+          <version>${project.version}</version>
+          <scope>provided</scope>
+        </dependency>
       </dependencies>
     </profile>
     <profile>
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
index c4a4b16b052..f8718edd97f 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
 import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.types._
 
-private[avro] case class AvroDataToCatalyst(
+private[sql] case class AvroDataToCatalyst(
     child: Expression,
     jsonFormatSchema: String,
     options: Map[String, String])
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
index 1e7e8600977..56ed117aef5 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{BinaryType, DataType}
 
-private[avro] case class CatalystDataToAvro(
+private[sql] case class CatalystDataToAvro(
     child: Expression,
     jsonFormatSchema: Option[String]) extends UnaryExpression {
 
diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml
index 838d7bf2bd3..a62c420bcc0 100644
--- a/connector/connect/server/pom.xml
+++ b/connector/connect/server/pom.xml
@@ -105,6 +105,12 @@
         </exclusion>
       </exclusions>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-avro_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e7e88cab643..d5baca9e17f 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -1256,6 +1257,40 @@ class SparkConnectPlanner(val session: SparkSession) {
           None
         }
 
+      // Avro-specific functions
+      case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
+        val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+        val jsonFormatSchema = children(1) match {
+          case Literal(s, StringType) if s != null => s.toString
+          case other =>
+            throw InvalidPlanInput(
+              s"jsonFormatSchema in from_avro should be a literal string, but got $other")
+        }
+        var options = Map.empty[String, String]
+        if (fun.getArgumentsCount == 3) {
+          children(2) match {
+            case UnresolvedFunction(Seq("map"), arguments, _, _, _) =>
+              options = ExprUtils.convertToMapData(CreateMap(arguments))
+            case other =>
+              throw InvalidPlanInput(
+                s"Options in from_avro should be created by map, but got $other")
+          }
+        }
+        Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options))
+
+      case "to_avro" if Seq(1, 2).contains(fun.getArgumentsCount) =>
+        val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+        var jsonFormatSchema = Option.empty[String]
+        if (fun.getArgumentsCount == 2) {
+          children(1) match {
+            case Literal(s, StringType) if s != null => jsonFormatSchema = Some(s.toString)
+            case other =>
+              throw InvalidPlanInput(
+                s"jsonFormatSchema in to_avro should be a literal string, but got $other")
+          }
+        }
+        Some(CatalystDataToAvro(children.head, jsonFormatSchema))
+
       // PS(Pandas API on Spark)-specific functions
       case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
         Some(DistributedSequenceID())
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 11257841bce..f65ef7e3ac0 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -273,7 +273,7 @@ sql_kafka = Module(
 
 connect = Module(
     name="connect",
-    dependencies=[hive],
+    dependencies=[hive, avro],
     source_file_regexes=[
         "connector/connect",
     ],
@@ -748,6 +748,7 @@ pyspark_connect = Module(
         "pyspark.sql.connect.readwriter",
         "pyspark.sql.connect.dataframe",
         "pyspark.sql.connect.functions",
+        "pyspark.sql.connect.avro.functions",
         # sql unittests
         "pyspark.sql.tests.connect.test_client",
         "pyspark.sql.tests.connect.test_connect_plan",
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index 080e45934e6..e49953e8953 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -25,13 +25,14 @@ from typing import Dict, Optional, TYPE_CHECKING, cast
 from py4j.java_gateway import JVMView
 
 from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions
 from pyspark.util import _print_missing_jar
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
 
+@try_remote_avro_functions
 def from_avro(
     data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
 ) -> Column:
@@ -44,6 +45,9 @@ def from_avro(
 
     .. versionadded:: 3.0.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     data : :class:`~pyspark.sql.Column` or str
@@ -88,12 +92,16 @@ def from_avro(
     return Column(jc)
 
 
+@try_remote_avro_functions
 def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
     """
     Converts a column into binary of avro format.
 
     .. versionadded:: 3.0.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     data : :class:`~pyspark.sql.Column` or str
diff --git a/python/pyspark/sql/connect/avro/__init__.py b/python/pyspark/sql/connect/avro/__init__.py
new file mode 100644
index 00000000000..6d29d44cb9c
--- /dev/null
+++ b/python/pyspark/sql/connect/avro/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - Avro Functions"""
diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py
new file mode 100644
index 00000000000..acd7fa63054
--- /dev/null
+++ b/python/pyspark/sql/connect/avro/functions.py
@@ -0,0 +1,108 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of builtin Avro functions
+"""
+
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from typing import Dict, Optional, TYPE_CHECKING
+
+from pyspark.sql.avro import functions as PyAvroFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, _options_to_col, lit
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import ColumnOrName
+
+
+def from_avro(
+    data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
+) -> Column:
+    if options is None:
+        return _invoke_function("from_avro", _to_col(data), lit(jsonFormatSchema))
+    else:
+        return _invoke_function(
+            "from_avro", _to_col(data), lit(jsonFormatSchema), _options_to_col(options)
+        )
+
+
+from_avro.__doc__ = PyAvroFunctions.from_avro.__doc__
+
+
+def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
+    if jsonFormatSchema == "":
+        return _invoke_function("to_avro", _to_col(data))
+    else:
+        return _invoke_function("to_avro", _to_col(data), lit(jsonFormatSchema))
+
+
+to_avro.__doc__ = PyAvroFunctions.to_avro.__doc__
+
+
+def _test() -> None:
+    import os
+    import sys
+    from pyspark.testing.utils import search_jar
+
+    avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
+
+    if avro_jar is None:
+        print(
+            "Skipping all Avro Python tests as the optional Avro project was "
+            "not compiled into a JAR. To run these tests, "
+            "you need to build Spark with 'build/sbt -Pavro package' or "
+            "'build/mvn -Pavro package' before running this test."
+        )
+        sys.exit(0)
+    else:
+        existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+        jars_args = "--jars %s" % avro_jar
+        os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
+
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.avro.functions
+
+    globs = pyspark.sql.connect.avro.functions.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.connect.avro.functions tests")
+        .remote("local[4]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.connect.avro.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b5d17e38b87..6f75325e0d8 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -161,6 +161,22 @@ def try_remote_functions(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+def try_remote_avro_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.sql.connect.avro import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
 def try_remote_window(f: FuncT) -> FuncT:
     """Mark API supported from Spark Connect."""
 


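The piece that ties this together is the `try_remote_avro_functions` decorator added to python/pyspark/sql/utils.py above: when `is_remote()` is true and `PYSPARK_NO_NAMESPACE_SHARE` is unset, a call to `pyspark.sql.avro.functions.from_avro`/`to_avro` is forwarded to the function of the same name in `pyspark.sql.connect.avro.functions`; otherwise the classic Py4J implementation runs. A minimal standalone sketch of the same pattern (hypothetical helper names, not the actual Spark code):

```python
import functools
import os


def try_remote(connect_module):
    """Forward calls to the same-named function in `connect_module` when remote."""

    def decorator(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Stand-in for pyspark's is_remote() / PYSPARK_NO_NAMESPACE_SHARE check.
            if "SPARK_REMOTE" in os.environ and "NO_NAMESPACE_SHARE" not in os.environ:
                return getattr(connect_module, f.__name__)(*args, **kwargs)
            return f(*args, **kwargs)

        return wrapped

    return decorator
```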