Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/25 01:32:57 UTC
[spark] branch master updated: [SPARK-41834][CONNECT] Implement SparkSession.conf
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 47951c9ab98 [SPARK-41834][CONNECT] Implement SparkSession.conf
47951c9ab98 is described below
commit 47951c9ab98523665530b291218073c885183184
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Sat Feb 25 10:32:43 2023 +0900
[SPARK-41834][CONNECT] Implement SparkSession.conf
### What changes were proposed in this pull request?
Implements `SparkSession.conf`.
Took #39995 over.
### Why are the changes needed?
`SparkSession.conf` is a missing feature.
### Does this PR introduce _any_ user-facing change?
Yes, `SparkSession.conf` will be available.
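As a minimal sketch of the user-facing API this enables (assuming a Spark Connect server is reachable; "sc://localhost" is a placeholder endpoint):

    from pyspark.sql import SparkSession

    # Connect to a Spark Connect server; the URL below is a placeholder.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    spark.conf.set("spark.sql.shuffle.partitions", 8)  # ints/bools are coerced to strings
    assert spark.conf.get("spark.sql.shuffle.partitions") == "8"
    spark.conf.unset("spark.sql.shuffle.partitions")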
### How was this patch tested?
Added/enabled related tests.
Closes #40150 from ueshin/issues/SPARK-41834/conf.
Lead-authored-by: Takuya UESHIN <ue...@databricks.com>
Co-authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../src/main/protobuf/spark/connect/base.proto | 94 ++++++
.../service/SparkConnectConfigHandler.scala | 181 +++++++++++
.../sql/connect/service/SparkConnectService.scala | 15 +
dev/sparktestsupport/modules.py | 2 +
python/pyspark/pandas/utils.py | 2 +-
python/pyspark/sql/conf.py | 27 +-
python/pyspark/sql/connect/client.py | 52 ++++
python/pyspark/sql/connect/conf.py | 125 ++++++++
python/pyspark/sql/connect/functions.py | 5 -
python/pyspark/sql/connect/proto/base_pb2.py | 152 ++++++++-
python/pyspark/sql/connect/proto/base_pb2.pyi | 342 +++++++++++++++++++++
python/pyspark/sql/connect/proto/base_pb2_grpc.py | 45 +++
python/pyspark/sql/connect/session.py | 5 +-
python/pyspark/sql/context.py | 6 +-
.../sql/tests/connect/test_connect_basic.py | 1 -
.../pyspark/sql/tests/connect/test_parity_conf.py | 36 +++
.../sql/tests/connect/test_parity_dataframe.py | 20 --
python/pyspark/sql/tests/test_conf.py | 29 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 14 +-
19 files changed, 1108 insertions(+), 45 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 5f9a4411ecd..1ffbb8aa881 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -183,6 +183,97 @@ message ExecutePlanResponse {
}
}
+// The key-value pair for the config request and response.
+message KeyValue {
+ // (Required) The key.
+ string key = 1;
+ // (Optional) The value.
+ optional string value = 2;
+}
+
+// Request to update or fetch the configurations.
+message ConfigRequest {
+ // (Required)
+ //
+ // The client_id is set by the client to be able to collate streaming responses from
+ // different queries.
+ string client_id = 1;
+
+ // (Required) User context
+ UserContext user_context = 2;
+
+ // (Required) The operation for the config.
+ Operation operation = 3;
+
+ // Provides optional information about the client sending the request. This field
+ // can be used for language or version specific information and is only intended for
+ // logging purposes and will not be interpreted by the server.
+ optional string client_type = 4;
+
+ message Operation {
+ oneof op_type {
+ Set set = 1;
+ Get get = 2;
+ GetWithDefault get_with_default = 3;
+ GetOption get_option = 4;
+ GetAll get_all = 5;
+ Unset unset = 6;
+ IsModifiable is_modifiable = 7;
+ }
+ }
+
+ message Set {
+ // (Required) The config key-value pairs to set.
+ repeated KeyValue pairs = 1;
+ }
+
+ message Get {
+ // (Required) The config keys to get.
+ repeated string keys = 1;
+ }
+
+ message GetWithDefault {
+ // (Required) The config key-value pairs to get. The value will be used as the default value.
+ repeated KeyValue pairs = 1;
+ }
+
+ message GetOption {
+ // (Required) The config keys to get optionally.
+ repeated string keys = 1;
+ }
+
+ message GetAll {
+ // (Optional) The prefix of the config key to get.
+ optional string prefix = 1;
+ }
+
+ message Unset {
+ // (Required) The config keys to unset.
+ repeated string keys = 1;
+ }
+
+ message IsModifiable {
+ // (Required) The config keys to check whether the configs are modifiable.
+ repeated string keys = 1;
+ }
+}
+
+// Response to the config request.
+message ConfigResponse {
+ string client_id = 1;
+
+ // (Optional) The result key-value pairs.
+ //
+ // Available when the operation is 'Get', 'GetWithDefault', 'GetOption', 'GetAll'.
+ // Also available for the operation 'IsModifiable' with boolean strings "true" and "false".
+ repeated KeyValue pairs = 2;
+
+ // (Optional)
+ //
+ // Warning messages for deprecated or unsupported configurations.
+ repeated string warnings = 3;
+}
+
// Main interface for the SparkConnect service.
service SparkConnectService {
@@ -193,5 +284,8 @@ service SparkConnectService {
// Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query.
rpc AnalyzePlan(AnalyzePlanRequest) returns (AnalyzePlanResponse) {}
+
+ // Updates or fetches the configurations and returns a [[ConfigResponse]] containing the result.
+ rpc Config(ConfigRequest) returns (ConfigResponse) {}
}
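For illustration, the generated Python classes for these messages (added later in this diff) can build a request like the following hedged sketch; the client_id here is a placeholder, not a real session identifier:

    from pyspark.sql.connect import proto

    # Build a ConfigRequest carrying a GetWithDefault operation,
    # mirroring the message shapes defined above.
    op = proto.ConfigRequest.Operation(
        get_with_default=proto.ConfigRequest.GetWithDefault(
            pairs=[proto.KeyValue(key="spark.sql.ansi.enabled", value="false")]
        )
    )
    req = proto.ConfigRequest(client_id="placeholder-session-id", operation=op)
    assert req.operation.WhichOneof("op_type") == "get_with_default"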
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala
new file mode 100644
index 00000000000..84f625222a8
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import scala.collection.JavaConverters._
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.RuntimeConfig
+import org.apache.spark.sql.internal.SQLConf
+
+class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigResponse])
+ extends Logging {
+
+ def handle(request: proto.ConfigRequest): Unit = {
+ val session =
+ SparkConnectService
+ .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getClientId)
+ .session
+
+ val builder = request.getOperation.getOpTypeCase match {
+ case proto.ConfigRequest.Operation.OpTypeCase.SET =>
+ handleSet(request.getOperation.getSet, session.conf)
+ case proto.ConfigRequest.Operation.OpTypeCase.GET =>
+ handleGet(request.getOperation.getGet, session.conf)
+ case proto.ConfigRequest.Operation.OpTypeCase.GET_WITH_DEFAULT =>
+ handleGetWithDefault(request.getOperation.getGetWithDefault, session.conf)
+ case proto.ConfigRequest.Operation.OpTypeCase.GET_OPTION =>
+ handleGetOption(request.getOperation.getGetOption, session.conf)
+ case proto.ConfigRequest.Operation.OpTypeCase.GET_ALL =>
+ handleGetAll(request.getOperation.getGetAll, session.conf)
+ case proto.ConfigRequest.Operation.OpTypeCase.UNSET =>
+ handleUnset(request.getOperation.getUnset, session.conf)
+ case proto.ConfigRequest.Operation.OpTypeCase.IS_MODIFIABLE =>
+ handleIsModifiable(request.getOperation.getIsModifiable, session.conf)
+ case _ => throw new UnsupportedOperationException(s"${request.getOperation} not supported.")
+ }
+
+ builder.setClientId(request.getClientId)
+ responseObserver.onNext(builder.build())
+ responseObserver.onCompleted()
+ }
+
+ private def handleSet(
+ operation: proto.ConfigRequest.Set,
+ conf: RuntimeConfig): proto.ConfigResponse.Builder = {
+ val builder = proto.ConfigResponse.newBuilder()
+ operation.getPairsList.asScala.iterator.foreach { pair =>
+ val (key, value) = SparkConnectConfigHandler.toKeyValue(pair)
+ conf.set(key, value.orNull)
+ getWarning(key).foreach(builder.addWarnings)
+ }
+ builder
+ }
+
+ private def handleGet(
+ operation: proto.ConfigRequest.Get,
+ conf: RuntimeConfig): proto.ConfigResponse.Builder = {
+ val builder = proto.ConfigResponse.newBuilder()
+ operation.getKeysList.asScala.iterator.foreach { key =>
+ val value = conf.get(key)
+ builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value)))
+ getWarning(key).foreach(builder.addWarnings)
+ }
+ builder
+ }
+
+ private def handleGetWithDefault(
+ operation: proto.ConfigRequest.GetWithDefault,
+ conf: RuntimeConfig): proto.ConfigResponse.Builder = {
+ val builder = proto.ConfigResponse.newBuilder()
+ operation.getPairsList.asScala.iterator.foreach { pair =>
+ val (key, default) = SparkConnectConfigHandler.toKeyValue(pair)
+ val value = conf.get(key, default.orNull)
+ builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value)))
+ getWarning(key).foreach(builder.addWarnings)
+ }
+ builder
+ }
+
+ private def handleGetOption(
+ operation: proto.ConfigRequest.GetOption,
+ conf: RuntimeConfig): proto.ConfigResponse.Builder = {
+ val builder = proto.ConfigResponse.newBuilder()
+ operation.getKeysList.asScala.iterator.foreach { key =>
+ val value = conf.getOption(key)
+ builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, value))
+ getWarning(key).foreach(builder.addWarnings)
+ }
+ builder
+ }
+
+ private def handleGetAll(
+ operation: proto.ConfigRequest.GetAll,
+ conf: RuntimeConfig): proto.ConfigResponse.Builder = {
+ val builder = proto.ConfigResponse.newBuilder()
+ val results = if (operation.hasPrefix) {
+ val prefix = operation.getPrefix
+ conf.getAll.iterator
+ .filter { case (key, _) => key.startsWith(prefix) }
+ .map { case (key, value) => (key.substring(prefix.length), value) }
+ } else {
+ conf.getAll.iterator
+ }
+ results.foreach { case (key, value) =>
+ builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value)))
+ getWarning(key).foreach(builder.addWarnings)
+ }
+ builder
+ }
+
+ private def handleUnset(
+ operation: proto.ConfigRequest.Unset,
+ conf: RuntimeConfig): proto.ConfigResponse.Builder = {
+ val builder = proto.ConfigResponse.newBuilder()
+ operation.getKeysList.asScala.iterator.foreach { key =>
+ conf.unset(key)
+ getWarning(key).foreach(builder.addWarnings)
+ }
+ builder
+ }
+
+ private def handleIsModifiable(
+ operation: proto.ConfigRequest.IsModifiable,
+ conf: RuntimeConfig): proto.ConfigResponse.Builder = {
+ val builder = proto.ConfigResponse.newBuilder()
+ operation.getKeysList.asScala.iterator.foreach { key =>
+ val value = conf.isModifiable(key)
+ builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value.toString)))
+ getWarning(key).foreach(builder.addWarnings)
+ }
+ builder
+ }
+
+ private def getWarning(key: String): Option[String] = {
+ if (SparkConnectConfigHandler.unsupportedConfigurations.contains(key)) {
+ Some(s"The SQL config '$key' is NOT supported in Spark Connect")
+ } else {
+ SQLConf.deprecatedSQLConfigs.get(key).map(_.toDeprecationString)
+ }
+ }
+}
+
+object SparkConnectConfigHandler {
+
+ private[connect] val unsupportedConfigurations = Set("spark.sql.execution.arrow.enabled")
+
+ def toKeyValue(pair: proto.KeyValue): (String, Option[String]) = {
+ val key = pair.getKey
+ val value = if (pair.hasValue) {
+ Some(pair.getValue)
+ } else {
+ None
+ }
+ (key, value)
+ }
+
+ def toProtoKeyValue(key: String, value: Option[String]): proto.KeyValue = {
+ val builder = proto.KeyValue.newBuilder()
+ builder.setKey(key)
+ value.foreach(builder.setValue)
+ builder.build()
+ }
+}
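One subtlety in handleGetAll above: when a prefix is given, matching keys are returned with the prefix stripped. A small Python sketch of that semantic, with a plain dict standing in for the session conf (all names here are hypothetical):

    all_confs = {
        "spark.sql.ansi.enabled": "false",
        "spark.sql.shuffle.partitions": "200",
        "spark.app.name": "demo",
    }

    def get_all(prefix=None):
        # Mirrors handleGetAll: filter by prefix, then strip it from each key.
        if prefix is None:
            return dict(all_confs)
        return {k[len(prefix):]: v for k, v in all_confs.items() if k.startswith(prefix)}

    assert get_all("spark.sql.") == {"ansi.enabled": "false", "shuffle.partitions": "200"}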
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 959aceaf46a..227067e2faf 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -212,6 +212,21 @@ class SparkConnectService(debug: Boolean)
response.setIsStreaming(ds.isStreaming)
response.addAllInputFiles(ds.inputFiles.toSeq.asJava)
}
+
+ /**
+ * This is the main entry method for Spark Connect and all calls to update or fetch
+ * the configurations.
+ *
+ * @param request
+ * @param responseObserver
+ */
+ override def config(
+ request: proto.ConfigRequest,
+ responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
+ try {
+ new SparkConnectConfigHandler(responseObserver).handle(request)
+ } catch handleError("config", observer = responseObserver)
+ }
}
/**
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index b849892e20a..b82f8dbb4d6 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -508,6 +508,7 @@ pyspark_connect = Module(
python_test_goals=[
# doctests
"pyspark.sql.connect.catalog",
+ "pyspark.sql.connect.conf",
"pyspark.sql.connect.group",
"pyspark.sql.connect.session",
"pyspark.sql.connect.window",
@@ -523,6 +524,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_connect_column",
"pyspark.sql.tests.connect.test_parity_datasources",
"pyspark.sql.tests.connect.test_parity_catalog",
+ "pyspark.sql.tests.connect.test_parity_conf",
"pyspark.sql.tests.connect.test_parity_serde",
"pyspark.sql.tests.connect.test_parity_functions",
"pyspark.sql.tests.connect.test_parity_group",
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index 9deb0147e66..c48dc8449cd 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -473,7 +473,7 @@ def default_session() -> SparkSession:
# Turn ANSI off when testing the pandas API on Spark since
# the behavior of pandas API on Spark follows pandas, not SQL.
if is_testing():
- spark.conf.set("spark.sql.ansi.enabled", False) # type: ignore[arg-type]
+ spark.conf.set("spark.sql.ansi.enabled", False)
if spark.conf.get("spark.sql.ansi.enabled") == "true":
log_advice(
"The config 'spark.sql.ansi.enabled' is set to True. "
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 40a36a26701..e8b258c9bf8 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -28,6 +28,9 @@ class RuntimeConfig:
"""User-facing configuration API, accessible through `SparkSession.conf`.
Options set here are automatically propagated to the Hadoop configuration during I/O.
+
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
"""
def __init__(self, jconf: JavaObject) -> None:
@@ -35,14 +38,23 @@ class RuntimeConfig:
self._jconf = jconf
@since(2.0)
- def set(self, key: str, value: str) -> None:
- """Sets the given Spark runtime configuration property."""
+ def set(self, key: str, value: Union[str, int, bool]) -> None:
+ """Sets the given Spark runtime configuration property.
+
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+ """
self._jconf.set(key, value)
@since(2.0)
- def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str:
+ def get(
+ self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue
+ ) -> Optional[str]:
"""Returns the value of Spark runtime configuration property for the given key,
assuming it is set.
+
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
"""
self._checkType(key, "key")
if default is _NoValue:
@@ -54,7 +66,11 @@ class RuntimeConfig:
@since(2.0)
def unset(self, key: str) -> None:
- """Resets the configuration property for the given key."""
+ """Resets the configuration property for the given key.
+
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+ """
self._jconf.unset(key)
def _checkType(self, obj: Any, identifier: str) -> None:
@@ -68,6 +84,9 @@ class RuntimeConfig:
def isModifiable(self, key: str) -> bool:
"""Indicates whether the configuration property with the given key
is modifiable in the current session.
+
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
"""
return self._jconf.isModifiable(key)
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 7ae4645863b..d6a1df6ba93 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -402,6 +402,19 @@ class AnalyzeResult:
)
+class ConfigResult:
+ def __init__(self, pairs: List[Tuple[str, Optional[str]]], warnings: List[str]):
+ self.pairs = pairs
+ self.warnings = warnings
+
+ @classmethod
+ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
+ return ConfigResult(
+ pairs=[(pair.key, pair.value if pair.HasField("value") else None) for pair in pb.pairs],
+ warnings=list(pb.warnings),
+ )
+
+
class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
@@ -736,6 +749,45 @@ class SparkConnectClient(object):
metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
return table, metrics
+ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
+ req = pb2.ConfigRequest()
+ req.client_id = self._session_id
+ req.client_type = self._builder.userAgent
+ if self._user_id:
+ req.user_context.user_id = self._user_id
+ return req
+
+ def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult:
+ """
+ Call the config RPC of Spark Connect.
+
+ Parameters
+ ----------
+ operation : str
+ Operation kind
+
+ Returns
+ -------
+ The result of the config call.
+ """
+ req = self._config_request_with_metadata()
+ req.operation.CopyFrom(operation)
+ try:
+ for attempt in Retrying(
+ can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+ ):
+ with attempt:
+ resp = self._stub.Config(req, metadata=self._builder.metadata())
+ if resp.client_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for request:"
+ f"{resp.client_id} != {self._session_id}"
+ )
+ return ConfigResult.fromProto(resp)
+ raise SparkConnectException("Invalid state during retry exception handling.")
+ except grpc.RpcError as rpc_error:
+ self._handle_error(rpc_error)
+
def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
Error handling helper for dealing with GRPC Errors. On the server side, certain
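For reference, a hedged sketch of driving the new Config RPC directly through this client; normally SparkSession.conf wires this up, and the endpoint below is a placeholder:

    from pyspark.sql.connect import proto
    from pyspark.sql.connect.client import SparkConnectClient

    client = SparkConnectClient("sc://localhost")  # placeholder endpoint
    operation = proto.ConfigRequest.Operation(
        get_all=proto.ConfigRequest.GetAll(prefix="spark.sql.")
    )
    result = client.config(operation)  # ConfigResult(pairs=..., warnings=...)
    for key, value in result.pairs:
        print(key, value)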
diff --git a/python/pyspark/sql/connect/conf.py b/python/pyspark/sql/connect/conf.py
new file mode 100644
index 00000000000..d323de716c4
--- /dev/null
+++ b/python/pyspark/sql/connect/conf.py
@@ -0,0 +1,125 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, Optional, Union, cast
+import warnings
+
+from pyspark import _NoValue
+from pyspark._globals import _NoValueType
+from pyspark.sql.conf import RuntimeConfig as PySparkRuntimeConfig
+from pyspark.sql.connect import proto
+from pyspark.sql.connect.client import SparkConnectClient
+
+
+class RuntimeConf:
+ def __init__(self, client: SparkConnectClient) -> None:
+ """Create a new RuntimeConfig."""
+ self._client = client
+
+ __init__.__doc__ = PySparkRuntimeConfig.__init__.__doc__
+
+ def set(self, key: str, value: Union[str, int, bool]) -> None:
+ if isinstance(value, bool):
+ value = "true" if value else "false"
+ elif isinstance(value, int):
+ value = str(value)
+ op_set = proto.ConfigRequest.Set(pairs=[proto.KeyValue(key=key, value=value)])
+ operation = proto.ConfigRequest.Operation(set=op_set)
+ result = self._client.config(operation)
+ for warn in result.warnings:
+ warnings.warn(warn)
+
+ set.__doc__ = PySparkRuntimeConfig.set.__doc__
+
+ def get(
+ self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue
+ ) -> Optional[str]:
+ self._checkType(key, "key")
+ if default is _NoValue:
+ op_get = proto.ConfigRequest.Get(keys=[key])
+ operation = proto.ConfigRequest.Operation(get=op_get)
+ else:
+ if default is not None:
+ self._checkType(default, "default")
+ op_get_with_default = proto.ConfigRequest.GetWithDefault(
+ pairs=[proto.KeyValue(key=key, value=cast(Optional[str], default))]
+ )
+ operation = proto.ConfigRequest.Operation(get_with_default=op_get_with_default)
+ result = self._client.config(operation)
+ return result.pairs[0][1]
+
+ get.__doc__ = PySparkRuntimeConfig.get.__doc__
+
+ def unset(self, key: str) -> None:
+ op_unset = proto.ConfigRequest.Unset(keys=[key])
+ operation = proto.ConfigRequest.Operation(unset=op_unset)
+ result = self._client.config(operation)
+ for warn in result.warnings:
+ warnings.warn(warn)
+
+ unset.__doc__ = PySparkRuntimeConfig.unset.__doc__
+
+ def isModifiable(self, key: str) -> bool:
+ op_is_modifiable = proto.ConfigRequest.IsModifiable(keys=[key])
+ operation = proto.ConfigRequest.Operation(is_modifiable=op_is_modifiable)
+ result = self._client.config(operation)
+ if result.pairs[0][1] == "true":
+ return True
+ elif result.pairs[0][1] == "false":
+ return False
+ else:
+ raise ValueError(f"Unknown boolean value: {result.pairs[0][1]}")
+
+ isModifiable.__doc__ = PySparkRuntimeConfig.isModifiable.__doc__
+
+ def _checkType(self, obj: Any, identifier: str) -> None:
+ """Assert that an object is of type str."""
+ if not isinstance(obj, str):
+ raise TypeError(
+ "expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)
+ )
+
+
+RuntimeConf.__doc__ = PySparkRuntimeConfig.__doc__
+
+
+def _test() -> None:
+ import sys
+ import doctest
+ from pyspark.sql import SparkSession as PySparkSession
+ import pyspark.sql.connect.conf
+
+ globs = pyspark.sql.connect.conf.__dict__.copy()
+ globs["spark"] = (
+ PySparkSession.builder.appName("sql.connect.conf tests").remote("local[4]").getOrCreate()
+ )
+
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.connect.conf,
+ globs=globs,
+ optionflags=doctest.ELLIPSIS
+ | doctest.NORMALIZE_WHITESPACE
+ | doctest.IGNORE_EXCEPTION_DETAIL,
+ )
+
+ globs["spark"].stop()
+
+ if failure_count:
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
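A short usage sketch of the RuntimeConf above, assuming `spark` is a Spark Connect SparkSession; note the bool/int coercion in set() and the GetWithDefault path taken by get() when a default is supplied:

    spark.conf.set("my.custom.flag", True)  # booleans are sent as "true"/"false"
    assert spark.conf.get("my.custom.flag") == "true"
    assert spark.conf.get("my.missing.key", "fallback") == "fallback"
    assert spark.conf.isModifiable("spark.sql.shuffle.partitions")
    spark.conf.unset("my.custom.flag")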
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 7d61a86c8b5..87dfe90107d 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -2477,11 +2477,6 @@ def _test() -> None:
# Spark Connect does not support Spark Context but the test depends on that.
del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__
- # TODO(SPARK-41834): implement Dataframe.conf
- del pyspark.sql.connect.functions.from_unixtime.__doc__
- del pyspark.sql.connect.functions.timestamp_seconds.__doc__
- del pyspark.sql.connect.functions.unix_timestamp.__doc__
-
# TODO(SPARK-41843): Implement SparkSession.udf
del pyspark.sql.connect.functions.call_udf.__doc__
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 0d86ce8cd68..95951d8f8e3 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
+ b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
)
@@ -58,6 +58,17 @@ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY = (
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[
"MetricValue"
]
+_KEYVALUE = DESCRIPTOR.message_types_by_name["KeyValue"]
+_CONFIGREQUEST = DESCRIPTOR.message_types_by_name["ConfigRequest"]
+_CONFIGREQUEST_OPERATION = _CONFIGREQUEST.nested_types_by_name["Operation"]
+_CONFIGREQUEST_SET = _CONFIGREQUEST.nested_types_by_name["Set"]
+_CONFIGREQUEST_GET = _CONFIGREQUEST.nested_types_by_name["Get"]
+_CONFIGREQUEST_GETWITHDEFAULT = _CONFIGREQUEST.nested_types_by_name["GetWithDefault"]
+_CONFIGREQUEST_GETOPTION = _CONFIGREQUEST.nested_types_by_name["GetOption"]
+_CONFIGREQUEST_GETALL = _CONFIGREQUEST.nested_types_by_name["GetAll"]
+_CONFIGREQUEST_UNSET = _CONFIGREQUEST.nested_types_by_name["Unset"]
+_CONFIGREQUEST_ISMODIFIABLE = _CONFIGREQUEST.nested_types_by_name["IsModifiable"]
+_CONFIGRESPONSE = DESCRIPTOR.message_types_by_name["ConfigResponse"]
_EXPLAIN_EXPLAINMODE = _EXPLAIN.enum_types_by_name["ExplainMode"]
Plan = _reflection.GeneratedProtocolMessageType(
"Plan",
@@ -186,6 +197,119 @@ _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject)
_sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry)
_sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricValue)
+KeyValue = _reflection.GeneratedProtocolMessageType(
+ "KeyValue",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _KEYVALUE,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.KeyValue)
+ },
+)
+_sym_db.RegisterMessage(KeyValue)
+
+ConfigRequest = _reflection.GeneratedProtocolMessageType(
+ "ConfigRequest",
+ (_message.Message,),
+ {
+ "Operation": _reflection.GeneratedProtocolMessageType(
+ "Operation",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGREQUEST_OPERATION,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.Operation)
+ },
+ ),
+ "Set": _reflection.GeneratedProtocolMessageType(
+ "Set",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGREQUEST_SET,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.Set)
+ },
+ ),
+ "Get": _reflection.GeneratedProtocolMessageType(
+ "Get",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGREQUEST_GET,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.Get)
+ },
+ ),
+ "GetWithDefault": _reflection.GeneratedProtocolMessageType(
+ "GetWithDefault",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGREQUEST_GETWITHDEFAULT,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.GetWithDefault)
+ },
+ ),
+ "GetOption": _reflection.GeneratedProtocolMessageType(
+ "GetOption",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGREQUEST_GETOPTION,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.GetOption)
+ },
+ ),
+ "GetAll": _reflection.GeneratedProtocolMessageType(
+ "GetAll",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGREQUEST_GETALL,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.GetAll)
+ },
+ ),
+ "Unset": _reflection.GeneratedProtocolMessageType(
+ "Unset",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGREQUEST_UNSET,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.Unset)
+ },
+ ),
+ "IsModifiable": _reflection.GeneratedProtocolMessageType(
+ "IsModifiable",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGREQUEST_ISMODIFIABLE,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.IsModifiable)
+ },
+ ),
+ "DESCRIPTOR": _CONFIGREQUEST,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest)
+ },
+)
+_sym_db.RegisterMessage(ConfigRequest)
+_sym_db.RegisterMessage(ConfigRequest.Operation)
+_sym_db.RegisterMessage(ConfigRequest.Set)
+_sym_db.RegisterMessage(ConfigRequest.Get)
+_sym_db.RegisterMessage(ConfigRequest.GetWithDefault)
+_sym_db.RegisterMessage(ConfigRequest.GetOption)
+_sym_db.RegisterMessage(ConfigRequest.GetAll)
+_sym_db.RegisterMessage(ConfigRequest.Unset)
+_sym_db.RegisterMessage(ConfigRequest.IsModifiable)
+
+ConfigResponse = _reflection.GeneratedProtocolMessageType(
+ "ConfigResponse",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CONFIGRESPONSE,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ConfigResponse)
+ },
+)
+_sym_db.RegisterMessage(ConfigResponse)
+
_SPARKCONNECTSERVICE = DESCRIPTOR.services_by_name["SparkConnectService"]
if _descriptor._USE_C_DESCRIPTORS == False:
@@ -219,6 +343,28 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2017
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2019
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2107
- _SPARKCONNECTSERVICE._serialized_start = 2110
- _SPARKCONNECTSERVICE._serialized_end = 2309
+ _KEYVALUE._serialized_start = 2109
+ _KEYVALUE._serialized_end = 2174
+ _CONFIGREQUEST._serialized_start = 2177
+ _CONFIGREQUEST._serialized_end = 3203
+ _CONFIGREQUEST_OPERATION._serialized_start = 2395
+ _CONFIGREQUEST_OPERATION._serialized_end = 2893
+ _CONFIGREQUEST_SET._serialized_start = 2895
+ _CONFIGREQUEST_SET._serialized_end = 2947
+ _CONFIGREQUEST_GET._serialized_start = 2949
+ _CONFIGREQUEST_GET._serialized_end = 2974
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 2976
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 3039
+ _CONFIGREQUEST_GETOPTION._serialized_start = 3041
+ _CONFIGREQUEST_GETOPTION._serialized_end = 3072
+ _CONFIGREQUEST_GETALL._serialized_start = 3074
+ _CONFIGREQUEST_GETALL._serialized_end = 3122
+ _CONFIGREQUEST_UNSET._serialized_start = 3124
+ _CONFIGREQUEST_UNSET._serialized_end = 3151
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 3153
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 3187
+ _CONFIGRESPONSE._serialized_start = 3205
+ _CONFIGRESPONSE._serialized_end = 3325
+ _SPARKCONNECTSERVICE._serialized_start = 3328
+ _SPARKCONNECTSERVICE._serialized_end = 3600
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index ea82aaf21e2..f6c402b229f 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -570,3 +570,345 @@ class ExecutePlanResponse(google.protobuf.message.Message):
) -> None: ...
global___ExecutePlanResponse = ExecutePlanResponse
+
+class KeyValue(google.protobuf.message.Message):
+ """The key-value pair for the config request and response."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEY_FIELD_NUMBER: builtins.int
+ VALUE_FIELD_NUMBER: builtins.int
+ key: builtins.str
+ """(Required) The key."""
+ value: builtins.str
+ """(Optional) The value."""
+ def __init__(
+ self,
+ *,
+ key: builtins.str = ...,
+ value: builtins.str | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["_value", b"_value", "value", b"value"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_value", b"_value", "key", b"key", "value", b"value"
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_value", b"_value"]
+ ) -> typing_extensions.Literal["value"] | None: ...
+
+global___KeyValue = KeyValue
+
+class ConfigRequest(google.protobuf.message.Message):
+ """Request to update or fetch the configurations."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class Operation(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SET_FIELD_NUMBER: builtins.int
+ GET_FIELD_NUMBER: builtins.int
+ GET_WITH_DEFAULT_FIELD_NUMBER: builtins.int
+ GET_OPTION_FIELD_NUMBER: builtins.int
+ GET_ALL_FIELD_NUMBER: builtins.int
+ UNSET_FIELD_NUMBER: builtins.int
+ IS_MODIFIABLE_FIELD_NUMBER: builtins.int
+ @property
+ def set(self) -> global___ConfigRequest.Set: ...
+ @property
+ def get(self) -> global___ConfigRequest.Get: ...
+ @property
+ def get_with_default(self) -> global___ConfigRequest.GetWithDefault: ...
+ @property
+ def get_option(self) -> global___ConfigRequest.GetOption: ...
+ @property
+ def get_all(self) -> global___ConfigRequest.GetAll: ...
+ @property
+ def unset(self) -> global___ConfigRequest.Unset: ...
+ @property
+ def is_modifiable(self) -> global___ConfigRequest.IsModifiable: ...
+ def __init__(
+ self,
+ *,
+ set: global___ConfigRequest.Set | None = ...,
+ get: global___ConfigRequest.Get | None = ...,
+ get_with_default: global___ConfigRequest.GetWithDefault | None = ...,
+ get_option: global___ConfigRequest.GetOption | None = ...,
+ get_all: global___ConfigRequest.GetAll | None = ...,
+ unset: global___ConfigRequest.Unset | None = ...,
+ is_modifiable: global___ConfigRequest.IsModifiable | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "get",
+ b"get",
+ "get_all",
+ b"get_all",
+ "get_option",
+ b"get_option",
+ "get_with_default",
+ b"get_with_default",
+ "is_modifiable",
+ b"is_modifiable",
+ "op_type",
+ b"op_type",
+ "set",
+ b"set",
+ "unset",
+ b"unset",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "get",
+ b"get",
+ "get_all",
+ b"get_all",
+ "get_option",
+ b"get_option",
+ "get_with_default",
+ b"get_with_default",
+ "is_modifiable",
+ b"is_modifiable",
+ "op_type",
+ b"op_type",
+ "set",
+ b"set",
+ "unset",
+ b"unset",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["op_type", b"op_type"]
+ ) -> typing_extensions.Literal[
+ "set", "get", "get_with_default", "get_option", "get_all", "unset", "is_modifiable"
+ ] | None: ...
+
+ class Set(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ PAIRS_FIELD_NUMBER: builtins.int
+ @property
+ def pairs(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___KeyValue]:
+ """(Required) The config key-value pairs to set."""
+ def __init__(
+ self,
+ *,
+ pairs: collections.abc.Iterable[global___KeyValue] | None = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal["pairs", b"pairs"]) -> None: ...
+
+ class Get(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEYS_FIELD_NUMBER: builtins.int
+ @property
+ def keys(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Required) The config keys to get."""
+ def __init__(
+ self,
+ *,
+ keys: collections.abc.Iterable[builtins.str] | None = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ...
+
+ class GetWithDefault(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ PAIRS_FIELD_NUMBER: builtins.int
+ @property
+ def pairs(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___KeyValue]:
+ """(Required) The config key-value paris to get. The value will be used as the default value."""
+ def __init__(
+ self,
+ *,
+ pairs: collections.abc.Iterable[global___KeyValue] | None = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal["pairs", b"pairs"]) -> None: ...
+
+ class GetOption(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEYS_FIELD_NUMBER: builtins.int
+ @property
+ def keys(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Required) The config keys to get optionally."""
+ def __init__(
+ self,
+ *,
+ keys: collections.abc.Iterable[builtins.str] | None = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ...
+
+ class GetAll(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ PREFIX_FIELD_NUMBER: builtins.int
+ prefix: builtins.str
+ """(Optional) The prefix of the config key to get."""
+ def __init__(
+ self,
+ *,
+ prefix: builtins.str | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["_prefix", b"_prefix", "prefix", b"prefix"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["_prefix", b"_prefix", "prefix", b"prefix"]
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_prefix", b"_prefix"]
+ ) -> typing_extensions.Literal["prefix"] | None: ...
+
+ class Unset(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEYS_FIELD_NUMBER: builtins.int
+ @property
+ def keys(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Required) The config keys to unset."""
+ def __init__(
+ self,
+ *,
+ keys: collections.abc.Iterable[builtins.str] | None = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ...
+
+ class IsModifiable(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEYS_FIELD_NUMBER: builtins.int
+ @property
+ def keys(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Required) The config keys to check the config is modifiable."""
+ def __init__(
+ self,
+ *,
+ keys: collections.abc.Iterable[builtins.str] | None = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ...
+
+ CLIENT_ID_FIELD_NUMBER: builtins.int
+ USER_CONTEXT_FIELD_NUMBER: builtins.int
+ OPERATION_FIELD_NUMBER: builtins.int
+ CLIENT_TYPE_FIELD_NUMBER: builtins.int
+ client_id: builtins.str
+ """(Required)
+
+ The client_id is set by the client to be able to collate streaming responses from
+ different queries.
+ """
+ @property
+ def user_context(self) -> global___UserContext:
+ """(Required) User context"""
+ @property
+ def operation(self) -> global___ConfigRequest.Operation:
+ """(Required) The operation for the config."""
+ client_type: builtins.str
+ """Provides optional information about the client sending the request. This field
+ can be used for language or version specific information and is only intended for
+ logging purposes and will not be interpreted by the server.
+ """
+ def __init__(
+ self,
+ *,
+ client_id: builtins.str = ...,
+ user_context: global___UserContext | None = ...,
+ operation: global___ConfigRequest.Operation | None = ...,
+ client_type: builtins.str | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_type",
+ b"_client_type",
+ "client_type",
+ b"client_type",
+ "operation",
+ b"operation",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_type",
+ b"_client_type",
+ "client_id",
+ b"client_id",
+ "client_type",
+ b"client_type",
+ "operation",
+ b"operation",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"]
+ ) -> typing_extensions.Literal["client_type"] | None: ...
+
+global___ConfigRequest = ConfigRequest
+
+class ConfigResponse(google.protobuf.message.Message):
+ """Response to the config request."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ CLIENT_ID_FIELD_NUMBER: builtins.int
+ PAIRS_FIELD_NUMBER: builtins.int
+ WARNINGS_FIELD_NUMBER: builtins.int
+ client_id: builtins.str
+ @property
+ def pairs(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___KeyValue]:
+ """(Optional) The result key-value pairs.
+
+ Available when the operation is 'Get', 'GetWithDefault', 'GetOption', 'GetAll'.
+ Also available for the operation 'IsModifiable' with boolean strings "true" and "false".
+ """
+ @property
+ def warnings(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Optional)
+
+ Warning messages for deprecated or unsupported configurations.
+ """
+ def __init__(
+ self,
+ *,
+ client_id: builtins.str = ...,
+ pairs: collections.abc.Iterable[global___KeyValue] | None = ...,
+ warnings: collections.abc.Iterable[builtins.str] | None = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "client_id", b"client_id", "pairs", b"pairs", "warnings", b"warnings"
+ ],
+ ) -> None: ...
+
+global___ConfigResponse = ConfigResponse
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index aff5897f520..007e31fd0ea 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -40,6 +40,11 @@ class SparkConnectServiceStub(object):
request_serializer=spark_dot_connect_dot_base__pb2.AnalyzePlanRequest.SerializeToString,
response_deserializer=spark_dot_connect_dot_base__pb2.AnalyzePlanResponse.FromString,
)
+ self.Config = channel.unary_unary(
+ "/spark.connect.SparkConnectService/Config",
+ request_serializer=spark_dot_connect_dot_base__pb2.ConfigRequest.SerializeToString,
+ response_deserializer=spark_dot_connect_dot_base__pb2.ConfigResponse.FromString,
+ )
class SparkConnectServiceServicer(object):
@@ -60,6 +65,12 @@ class SparkConnectServiceServicer(object):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
+ def Config(self, request, context):
+ """Update or fetch the configurations and returns a [[ConfigResponse]] containing the result."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
def add_SparkConnectServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -73,6 +84,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server):
request_deserializer=spark_dot_connect_dot_base__pb2.AnalyzePlanRequest.FromString,
response_serializer=spark_dot_connect_dot_base__pb2.AnalyzePlanResponse.SerializeToString,
),
+ "Config": grpc.unary_unary_rpc_method_handler(
+ servicer.Config,
+ request_deserializer=spark_dot_connect_dot_base__pb2.ConfigRequest.FromString,
+ response_serializer=spark_dot_connect_dot_base__pb2.ConfigResponse.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
"spark.connect.SparkConnectService", rpc_method_handlers
@@ -141,3 +157,32 @@ class SparkConnectService(object):
timeout,
metadata,
)
+
+ @staticmethod
+ def Config(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/spark.connect.SparkConnectService/Config",
+ spark_dot_connect_dot_base__pb2.ConfigRequest.SerializeToString,
+ spark_dot_connect_dot_base__pb2.ConfigResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 08e63f544e2..c95279a8c8e 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -47,6 +47,7 @@ from pandas.api.types import ( # type: ignore[attr-defined]
from pyspark import SparkContext, SparkConf, __version__
from pyspark.sql.connect.client import SparkConnectClient
+from pyspark.sql.connect.conf import RuntimeConf
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import SQL, Range, LocalRelation
from pyspark.sql.connect.readwriter import DataFrameReader
@@ -421,8 +422,8 @@ class SparkSession:
raise NotImplementedError("newSession() is not implemented.")
@property
- def conf(self) -> Any:
- raise NotImplementedError("conf() is not implemented.")
+ def conf(self) -> RuntimeConf:
+ return RuntimeConf(self.client)
@property
def sparkContext(self) -> Any:
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 3c47ebfb973..99f97977ccc 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -191,9 +191,11 @@ class SQLContext:
.. versionadded:: 1.3.0
"""
- self.sparkSession.conf.set(key, value) # type: ignore[arg-type]
+ self.sparkSession.conf.set(key, value)
- def getConf(self, key: str, defaultValue: Union[Optional[str], _NoValueType] = _NoValue) -> str:
+ def getConf(
+ self, key: str, defaultValue: Union[Optional[str], _NoValueType] = _NoValue
+ ) -> Optional[str]:
"""Returns the value of Spark SQL configuration property for the given key.
If the key is not set and defaultValue is set, return
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index adcd457a105..84c3e4f23a6 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2796,7 +2796,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
for f in (
"newSession",
- "conf",
"sparkContext",
"streams",
"readStream",
diff --git a/python/pyspark/sql/tests/connect/test_parity_conf.py b/python/pyspark/sql/tests/connect/test_parity_conf.py
new file mode 100644
index 00000000000..554f05f27ea
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_conf.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.tests.test_conf import ConfTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ConfParityTests(ConfTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_conf import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 25fdbebd991..800fe4a2298 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -46,11 +46,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
def test_invalid_join_method(self):
super().test_invalid_join_method()
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_join_without_on(self):
- super().test_join_without_on()
-
# TODO(SPARK-41527): Implement DataFrame.observe
@unittest.skip("Fails in Spark Connect, should enable.")
def test_observe(self):
@@ -75,11 +70,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
def test_repr_behaviors(self):
super().test_repr_behaviors()
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_require_cross(self):
- super().test_require_cross()
-
# TODO(SPARK-41874): Implement DataFrame `sameSemantics`
@unittest.skip("Fails in Spark Connect, should enable.")
def test_same_semantics_error(self):
@@ -117,16 +107,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
# Spark Connect's implementation is based on Arrow.
super().check_to_pandas_for_array_of_struct(True)
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_to_pandas_from_empty_dataframe(self):
- super().test_to_pandas_from_empty_dataframe()
-
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_to_pandas_from_mixed_dataframe(self):
- super().test_to_pandas_from_mixed_dataframe()
-
# TODO(SPARK-41834): Implement SparkSession.conf
@unittest.skip("Fails in Spark Connect, should enable.")
def test_to_pandas_from_null_dataframe(self):
diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py
index a8fa59c0364..15722c2c57a 100644
--- a/python/pyspark/sql/tests/test_conf.py
+++ b/python/pyspark/sql/tests/test_conf.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from decimal import Decimal
+from pyspark.errors import IllegalArgumentException
from pyspark.testing.sqlutils import ReusedSQLTestCase
-class ConfTests(ReusedSQLTestCase):
+class ConfTestsMixin:
def test_conf(self):
spark = self.spark
spark.conf.set("bogo", "sipeo")
@@ -42,6 +44,31 @@ class ConfTests(ReusedSQLTestCase):
# `defaultValue` in `spark.conf.get` is set to None.
self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None)
+ self.assertTrue(spark.conf.isModifiable("spark.sql.execution.arrow.maxRecordsPerBatch"))
+ self.assertFalse(spark.conf.isModifiable("spark.sql.warehouse.dir"))
+
+ def test_conf_with_python_objects(self):
+ spark = self.spark
+
+ for value, expected in [(True, "true"), (False, "false")]:
+ spark.conf.set("foo", value)
+ self.assertEqual(spark.conf.get("foo"), expected)
+
+ spark.conf.set("foo", 1)
+ self.assertEqual(spark.conf.get("foo"), "1")
+
+ with self.assertRaises(IllegalArgumentException):
+ spark.conf.set("foo", None)
+
+ with self.assertRaises(Exception):
+ spark.conf.set("foo", Decimal(1))
+
+ spark.conf.unset("foo")
+
+
+class ConfTests(ConfTestsMixin, ReusedSQLTestCase):
+ pass
+
if __name__ == "__main__":
import unittest
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d015e7df32b..67a3f1b5fed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4177,7 +4177,12 @@ object SQLConf {
* @param comment Additional info regarding to the removed config. For example,
* reasons of config deprecation, what users should use instead of it.
*/
- case class DeprecatedConfig(key: String, version: String, comment: String)
+ case class DeprecatedConfig(key: String, version: String, comment: String) {
+ def toDeprecationString: String = {
+ s"The SQL config '$key' has been deprecated in Spark v$version " +
+ s"and may be removed in the future. $comment"
+ }
+ }
/**
* Maps deprecated SQL config keys to information about the deprecation.
@@ -5148,11 +5153,8 @@ class SQLConf extends Serializable with Logging {
* Logs a warning message if the given config key is deprecated.
*/
private def logDeprecationWarning(key: String): Unit = {
- SQLConf.deprecatedSQLConfigs.get(key).foreach {
- case DeprecatedConfig(configName, version, comment) =>
- logWarning(
- s"The SQL config '$configName' has been deprecated in Spark v$version " +
- s"and may be removed in the future. $comment")
+ SQLConf.deprecatedSQLConfigs.get(key).foreach { config =>
+ logWarning(config.toDeprecationString)
}
}
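Because the Connect handler reuses toDeprecationString, deprecation messages now also reach the Python client via ConfigResponse.warnings and surface as Python warnings. A hedged sketch, assuming `spark` is a Spark Connect SparkSession and using a config key that is deprecated at the time of writing:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        spark.conf.set("spark.sql.hive.verifyPartitionPath", "true")
    assert any("deprecated" in str(w.message) for w in caught)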