You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/29 00:59:06 UTC
[spark] branch master updated: [SPARK-42907][CONNECT][PYTHON] Implement Avro functions
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5a56c172831 [SPARK-42907][CONNECT][PYTHON] Implement Avro functions
5a56c172831 is described below
commit 5a56c17283103821714ffaaf1c764e05d0ff6b58
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Mar 29 09:58:52 2023 +0900
[SPARK-42907][CONNECT][PYTHON] Implement Avro functions
### What changes were proposed in this pull request?
Implement the Avro functions `from_avro` and `to_avro` for the Spark Connect Python client.
### Why are the changes needed?
For function parity with the existing PySpark API.
### Does this PR introduce _any_ user-facing change?
Yes, new APIs.
### How was this patch tested?
Added doctests and checked manually:
```
(spark_dev) ➜ spark git:(connect_avro_functions) ✗ bin/pyspark --remote "local[*]" --jars connector/avro/target/scala-2.12/spark-avro_2.12-3.5.0-SNAPSHOT.jar
Python 3.9.16 (main, Mar 8 2023, 04:29:24)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.11.0 -- An enhanced Interactive Python. Type '?' for help.
23/03/23 16:28:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/__ / .__/\_,_/_/ /_/\_\ version 3.5.0.dev0
/_/
Using Python version 3.9.16 (main, Mar 8 2023 04:29:24)
Client connected to the Spark Connect server at localhost
SparkSession available as 'spark'.
In [1]: >>> from pyspark.sql import Row
...: >>> from pyspark.sql.avro.functions import from_avro, to_avro
...: >>> data = [(1, Row(age=2, name='Alice'))]
...: >>> df = spark.createDataFrame(data, ("key", "value"))
...: >>> avroDf = df.select(to_avro(df.value).alias("avro"))
In [2]: avroDf.collect()
Out[2]: [Row(avro=bytearray(b'\x00\x00\x04\x00\nAlice'))]
```
Closes #40535 from zhengruifeng/connect_avro_functions.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
assembly/pom.xml | 6 ++
.../apache/spark/sql/avro/AvroDataToCatalyst.scala | 2 +-
.../apache/spark/sql/avro/CatalystDataToAvro.scala | 2 +-
connector/connect/server/pom.xml | 6 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 35 +++++++
dev/sparktestsupport/modules.py | 3 +-
python/pyspark/sql/avro/functions.py | 10 +-
python/pyspark/sql/connect/avro/__init__.py | 18 ++++
python/pyspark/sql/connect/avro/functions.py | 114 +++++++++++++++++++++
python/pyspark/sql/utils.py | 16 +++
10 files changed, 208 insertions(+), 4 deletions(-)
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 36cc6078438..09d6bd8a33f 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -160,6 +160,12 @@
<artifactId>spark-connect_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-avro_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
</dependencies>
</profile>
<profile>
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
index c4a4b16b052..f8718edd97f 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode}
import org.apache.spark.sql.types._
-private[avro] case class AvroDataToCatalyst(
+private[sql] case class AvroDataToCatalyst(
child: Expression,
jsonFormatSchema: String,
options: Map[String, String])
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
index 1e7e8600977..56ed117aef5 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}
-private[avro] case class CatalystDataToAvro(
+private[sql] case class CatalystDataToAvro(
child: Expression,
jsonFormatSchema: Option[String]) extends UnaryExpression {
diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml
index 838d7bf2bd3..a62c420bcc0 100644
--- a/connector/connect/server/pom.xml
+++ b/connector/connect/server/pom.xml
@@ -105,6 +105,12 @@
</exclusion>
</exclusions>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-avro_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e7e88cab643..d5baca9e17f 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -1256,6 +1257,40 @@ class SparkConnectPlanner(val session: SparkSession) {
None
}
+ // Avro-specific functions
+ // from_avro(data, jsonFormatSchema[, options]): the schema must be a non-null string
+ // literal, and the optional options argument must be built with the `map` function.
+ case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
+   val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+   val jsonFormatSchema = children(1) match {
+     case Literal(s, StringType) if s != null => s.toString
+     case other =>
+       throw InvalidPlanInput(
+         s"jsonFormatSchema in from_avro should be a literal string, but got $other")
+   }
+   // Options are only present in the 3-argument form; default to no options.
+   val options = if (fun.getArgumentsCount == 3) {
+     children(2) match {
+       case UnresolvedFunction(Seq("map"), arguments, _, _, _) =>
+         ExprUtils.convertToMapData(CreateMap(arguments))
+       case other =>
+         throw InvalidPlanInput(
+           s"Options in from_avro should be created by map, but got $other")
+     }
+   } else {
+     Map.empty[String, String]
+   }
+   Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options))
+
+ case "to_avro" if Seq(1, 2).contains(fun.getArgumentsCount) =>
+   val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+   // The optional second argument carries the Avro schema as a non-null string literal;
+   // in the 1-argument form no schema is passed on.
+   val jsonFormatSchema: Option[String] =
+     if (fun.getArgumentsCount == 2) {
+       children(1) match {
+         case Literal(s, StringType) if s != null => Some(s.toString)
+         case other =>
+           throw InvalidPlanInput(
+             s"jsonFormatSchema in to_avro should be a literal string, but got $other")
+       }
+     } else {
+       None
+     }
+   Some(CatalystDataToAvro(children.head, jsonFormatSchema))
+
// PS(Pandas API on Spark)-specific functions
case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
Some(DistributedSequenceID())
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 11257841bce..f65ef7e3ac0 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -273,7 +273,7 @@ sql_kafka = Module(
connect = Module(
name="connect",
- dependencies=[hive],
+ dependencies=[hive, avro],
source_file_regexes=[
"connector/connect",
],
@@ -748,6 +748,7 @@ pyspark_connect = Module(
"pyspark.sql.connect.readwriter",
"pyspark.sql.connect.dataframe",
"pyspark.sql.connect.functions",
+ "pyspark.sql.connect.avro.functions",
# sql unittests
"pyspark.sql.tests.connect.test_client",
"pyspark.sql.tests.connect.test_connect_plan",
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index 080e45934e6..e49953e8953 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -25,13 +25,14 @@ from typing import Dict, Optional, TYPE_CHECKING, cast
from py4j.java_gateway import JVMView
from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions
from pyspark.util import _print_missing_jar
if TYPE_CHECKING:
from pyspark.sql._typing import ColumnOrName
+@try_remote_avro_functions
def from_avro(
data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
) -> Column:
@@ -44,6 +45,9 @@ def from_avro(
.. versionadded:: 3.0.0
+ .. versionchanged:: 3.5.0
+ Supports Spark Connect.
+
Parameters
----------
data : :class:`~pyspark.sql.Column` or str
@@ -88,12 +92,16 @@ def from_avro(
return Column(jc)
+@try_remote_avro_functions
def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
"""
Converts a column into binary of avro format.
.. versionadded:: 3.0.0
+ .. versionchanged:: 3.5.0
+ Supports Spark Connect.
+
Parameters
----------
data : :class:`~pyspark.sql.Column` or str
diff --git a/python/pyspark/sql/connect/avro/__init__.py b/python/pyspark/sql/connect/avro/__init__.py
new file mode 100644
index 00000000000..6d29d44cb9c
--- /dev/null
+++ b/python/pyspark/sql/connect/avro/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - Avro Functions"""
diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py
new file mode 100644
index 00000000000..acd7fa63054
--- /dev/null
+++ b/python/pyspark/sql/connect/avro/functions.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of built-in Avro functions for Spark Connect.
+"""
+
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from typing import Dict, Optional, TYPE_CHECKING
+
+from pyspark.sql.avro import functions as PyAvroFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, _options_to_col, lit
+
+if TYPE_CHECKING:
+ from pyspark.sql.connect._typing import ColumnOrName
+
+
+def from_avro(
+ data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
+) -> Column:
+ if options is None:
+ return _invoke_function("from_avro", _to_col(data), lit(jsonFormatSchema))
+ else:
+ return _invoke_function(
+ "from_avro", _to_col(data), lit(jsonFormatSchema), _options_to_col(options)
+ )
+
+
+from_avro.__doc__ = PyAvroFunctions.from_avro.__doc__
+
+
+def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
+ if jsonFormatSchema == "":
+ return _invoke_function("to_avro", _to_col(data))
+ else:
+ return _invoke_function("to_avro", _to_col(data), lit(jsonFormatSchema))
+
+
+to_avro.__doc__ = PyAvroFunctions.to_avro.__doc__
+
+
+def _test() -> None:
+ import os
+ import sys
+ from pyspark.testing.utils import search_jar
+
+ avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
+
+ print()
+ print(avro_jar)
+ print(avro_jar)
+ print(avro_jar)
+ print()
+
+ if avro_jar is None:
+ print(
+ "Skipping all Avro Python tests as the optional Avro project was "
+ "not compiled into a JAR. To run these tests, "
+ "you need to build Spark with 'build/sbt -Pavro package' or "
+ "'build/mvn -Pavro package' before running this test."
+ )
+ sys.exit(0)
+ else:
+ existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+ jars_args = "--jars %s" % avro_jar
+ os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
+
+ import doctest
+ from pyspark.sql import SparkSession as PySparkSession
+ import pyspark.sql.connect.avro.functions
+
+ globs = pyspark.sql.connect.avro.functions.__dict__.copy()
+
+ globs["spark"] = (
+ PySparkSession.builder.appName("sql.connect.avro.functions tests")
+ .remote("local[4]")
+ .getOrCreate()
+ )
+
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.connect.avro.functions,
+ globs=globs,
+ optionflags=doctest.ELLIPSIS
+ | doctest.NORMALIZE_WHITESPACE
+ | doctest.IGNORE_EXCEPTION_DETAIL,
+ )
+
+ globs["spark"].stop()
+
+ if failure_count:
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b5d17e38b87..6f75325e0d8 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -161,6 +161,22 @@ def try_remote_functions(f: FuncT) -> FuncT:
return cast(FuncT, wrapped)
+def try_remote_avro_functions(f: FuncT) -> FuncT:
+ """Mark API supported from Spark Connect."""
+
+ @functools.wraps(f)
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+ from pyspark.sql.connect.avro import functions
+
+ return getattr(functions, f.__name__)(*args, **kwargs)
+ else:
+ return f(*args, **kwargs)
+
+ return cast(FuncT, wrapped)
+
+
def try_remote_window(f: FuncT) -> FuncT:
"""Mark API supported from Spark Connect."""
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org