You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2022/11/25 22:12:52 UTC

[spark] branch master updated: [SPARK-41255][CONNECT] Rename RemoteSparkSession

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 77e2d453b14 [SPARK-41255][CONNECT] Rename RemoteSparkSession
77e2d453b14 is described below

commit 77e2d453b14eca1ab6740e6c532394fc908050f4
Author: Martin Grund <ma...@databricks.com>
AuthorDate: Fri Nov 25 18:12:37 2022 -0400

    [SPARK-41255][CONNECT] Rename RemoteSparkSession
    
    ### What changes were proposed in this pull request?
    For better source compatibility, this PR renames `RemoteSparkSession` to `SparkConnectClient` and introduces a new `SparkSession` class that follows the same builder pattern as the existing PySpark `SparkSession`. Communication with the gRPC endpoint is kept in `SparkConnectClient` in `client.py`, whereas the public-facing Spark Session functionality is implemented in `SparkSession` in `session.py`.
    
    The new class does not yet support the full behavior of the existing `SparkSession`. To connect to Spark Connect with the new code, use one of the following examples:
    
    ```
    # Connection to a remote endpoint
    SparkSession.builder.remote("sc://endpoint/;config=abc").getOrCreate()
    ```
    
    or
    
    ```
    # Local connection to a locally running server
    SparkSession.builder.remote().getOrCreate()
    ```
    
    or
    
    ```
    SparkSession.builder.conf("spark.connect.location", "sc://endpoint").getOrCreate()
    ```
    
    ### Why are the changes needed?
    Source compatibility with the existing PySpark `SparkSession` API.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests (existing Spark Connect tests were updated).
    
    Closes #38792 from grundprinzip/SPARK-41255.
    
    Lead-authored-by: Martin Grund <ma...@databricks.com>
    Co-authored-by: Martin Grund <gr...@gmail.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 python/pyspark/sql/connect/client.py               |  57 +----
 python/pyspark/sql/connect/column.py               |  16 +-
 python/pyspark/sql/connect/dataframe.py            |  52 ++---
 python/pyspark/sql/connect/function_builder.py     |   4 +-
 python/pyspark/sql/connect/plan.py                 |  64 ++---
 python/pyspark/sql/connect/readwriter.py           |   4 +-
 python/pyspark/sql/connect/session.py              | 258 +++++++++++++++++++++
 .../sql/tests/connect/test_connect_basic.py        |   5 +-
 .../connect/test_connect_column_expressions.py     |   2 +-
 python/pyspark/testing/connectutils.py             |   6 +-
 10 files changed, 340 insertions(+), 128 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index b41df12c357..a2a0797c49f 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -15,23 +15,19 @@
 # limitations under the License.
 #
 
-
-import logging
 import os
 import urllib.parse
 import uuid
+from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
 
 import grpc  # type: ignore
-import pyarrow as pa
 import pandas
+import pyarrow as pa
 
 import pyspark.sql.connect.proto as pb2
 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
 import pyspark.sql.types
 from pyspark import cloudpickle
-from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.readwriter import DataFrameReader
-from pyspark.sql.connect.plan import SQL, Range
 from pyspark.sql.types import (
     DataType,
     ByteType,
@@ -56,10 +52,6 @@ from pyspark.sql.types import (
     NullType,
 )
 
-from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
-
-logging.basicConfig(level=logging.INFO)
-
 
 class ChannelBuilder:
     """
@@ -294,12 +286,12 @@ class AnalyzeResult:
         )
 
 
-class RemoteSparkSession(object):
+class SparkConnectClient(object):
     """Conceptually the remote spark session that communicates with the server"""
 
-    def __init__(self, connectionString: str = "sc://localhost", userId: Optional[str] = None):
+    def __init__(self, connectionString: str, userId: Optional[str] = None):
         """
-        Creates a new RemoteSparkSession for the Spark Connect interface.
+        Creates a new SparkSession for the Spark Connect interface.
 
         Parameters
         ----------
@@ -325,9 +317,6 @@ class RemoteSparkSession(object):
         self._channel = self._builder.toChannel()
         self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
 
-        # Create the reader
-        self.read = DataFrameReader(self)
-
     def register_udf(
         self, function: Any, return_type: Union[str, pyspark.sql.types.DataType]
     ) -> str:
@@ -355,42 +344,6 @@ class RemoteSparkSession(object):
             for x in metrics.metrics
         ]
 
-    def sql(self, sql_string: str) -> "DataFrame":
-        return DataFrame.withPlan(SQL(sql_string), self)
-
-    def range(
-        self,
-        start: int,
-        end: int,
-        step: int = 1,
-        numPartitions: Optional[int] = None,
-    ) -> DataFrame:
-        """
-        Create a :class:`DataFrame` with column named ``id`` and typed Long,
-        containing elements in a range from ``start`` to ``end`` (exclusive) with
-        step value ``step``.
-
-        .. versionadded:: 3.4.0
-
-        Parameters
-        ----------
-        start : int
-            the start value
-        end : int
-            the end value (exclusive)
-        step : int, optional
-            the incremental step (default: 1)
-        numPartitions : int, optional
-            the number of partitions of the DataFrame
-
-        Returns
-        -------
-        :class:`DataFrame`
-        """
-        return DataFrame.withPlan(
-            Range(start=start, end=end, step=step, num_partitions=numPartitions), self
-        )
-
     def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 36f38e0ded2..69f9fa72db6 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -24,7 +24,7 @@ import datetime
 import pyspark.sql.connect.proto as proto
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect.client import RemoteSparkSession
+    from pyspark.sql.connect.client import SparkConnectClient
     import pyspark.sql.connect.proto as proto
 
 
@@ -80,7 +80,7 @@ class Expression(object):
     def __init__(self) -> None:
         pass
 
-    def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         ...
 
     def __str__(self) -> str:
@@ -131,7 +131,7 @@ class ColumnAlias(Expression):
         self._metadata = metadata
         self._parent = parent
 
-    def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         if len(self._alias) == 1:
             exp = proto.Expression()
             exp.alias.name.append(self._alias[0])
@@ -162,7 +162,7 @@ class LiteralExpression(Expression):
         super().__init__()
         self._value = value
 
-    def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         """Converts the literal expression to the literal in proto.
 
         TODO(SPARK-40533) This method always assumes the largest type and can thus
@@ -250,7 +250,7 @@ class Column(Expression):
         """Returns the qualified name of the column reference."""
         return self._unparsed_identifier
 
-    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         """Returns the Proto representation of the expression."""
         expr = proto.Expression()
         expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
@@ -275,7 +275,7 @@ class SQLExpression(Expression):
         super().__init__()
         self._expr: str = expr
 
-    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         """Returns the Proto representation of the SQL expression."""
         expr = proto.Expression()
         expr.expression_string.expression = self._expr
@@ -292,7 +292,7 @@ class SortOrder(Expression):
     def __str__(self) -> str:
         return str(self.ref) + " ASC" if self.ascending else " DESC"
 
-    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         return self.ref.to_plan(session)
 
 
@@ -306,7 +306,7 @@ class ScalarFunctionExpression(Expression):
         self._args = args
         self._op = op
 
-    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         fun = proto.Expression()
         fun.unresolved_function.parts.append(self._op)
         fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args])
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 23340e46165..6fabab69cf5 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -46,7 +46,7 @@ from pyspark.sql.types import (
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType
-    from pyspark.sql.connect.client import RemoteSparkSession
+    from pyspark.sql.connect.session import SparkSession
 
 
 class GroupedData(object):
@@ -97,20 +97,20 @@ class DataFrame(object):
 
     def __init__(
         self,
-        session: "RemoteSparkSession",
+        session: "SparkSession",
         data: Optional[List[Any]] = None,
         schema: Optional[StructType] = None,
     ):
         """Creates a new data frame"""
         self._schema = schema
         self._plan: Optional[plan.LogicalPlan] = None
-        self._session: "RemoteSparkSession" = session
+        self._session: "SparkSession" = session
 
     def __repr__(self) -> str:
         return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
 
     @classmethod
-    def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame":
+    def withPlan(cls, plan: plan.LogicalPlan, session: "SparkSession") -> "DataFrame":
         """Main initialization method used to construct a new data frame with a child plan."""
         new_frame = DataFrame(session=session)
         new_frame._plan = plan
@@ -197,14 +197,14 @@ class DataFrame(object):
 
         return self.schema.names
 
-    def sparkSession(self) -> "RemoteSparkSession":
+    def sparkSession(self) -> "SparkSession":
         """Returns Spark session that created this :class:`DataFrame`.
 
         .. versionadded:: 3.4.0
 
         Returns
         -------
-        :class:`RemoteSparkSession`
+        :class:`SparkSession`
         """
         return self._session
 
@@ -796,8 +796,8 @@ class DataFrame(object):
             raise Exception("Cannot collect on empty plan.")
         if self._session is None:
             raise Exception("Cannot collect on empty session.")
-        query = self._plan.to_proto(self._session)
-        return self._session._to_pandas(query)
+        query = self._plan.to_proto(self._session.client)
+        return self._session.client._to_pandas(query)
 
     @property
     def schema(self) -> StructType:
@@ -811,10 +811,10 @@ class DataFrame(object):
         """
         if self._schema is None:
             if self._plan is not None:
-                query = self._plan.to_proto(self._session)
+                query = self._plan.to_proto(self._session.client)
                 if self._session is None:
-                    raise Exception("Cannot analyze without RemoteSparkSession.")
-                self._schema = self._session.schema(query)
+                    raise Exception("Cannot analyze without SparkSession.")
+                self._schema = self._session.client.schema(query)
                 return self._schema
             else:
                 raise Exception("Empty plan.")
@@ -834,8 +834,8 @@ class DataFrame(object):
         """
         if self._plan is None:
             raise Exception("Cannot analyze on empty plan.")
-        query = self._plan.to_proto(self._session)
-        return self._session._analyze(query).is_local
+        query = self._plan.to_proto(self._session.client)
+        return self._session.client._analyze(query).is_local
 
     @property
     def isStreaming(self) -> bool:
@@ -859,14 +859,14 @@ class DataFrame(object):
         """
         if self._plan is None:
             raise Exception("Cannot analyze on empty plan.")
-        query = self._plan.to_proto(self._session)
-        return self._session._analyze(query).is_streaming
+        query = self._plan.to_proto(self._session.client)
+        return self._session.client._analyze(query).is_streaming
 
     def _tree_string(self) -> str:
         if self._plan is None:
             raise Exception("Cannot analyze on empty plan.")
-        query = self._plan.to_proto(self._session)
-        return self._session._analyze(query).tree_string
+        query = self._plan.to_proto(self._session.client)
+        return self._session.client._analyze(query).tree_string
 
     def printSchema(self) -> None:
         """Prints out the schema in the tree format.
@@ -895,8 +895,8 @@ class DataFrame(object):
         """
         if self._plan is None:
             raise Exception("Cannot analyze on empty plan.")
-        query = self._plan.to_proto(self._session)
-        return self._session._analyze(query).input_files
+        query = self._plan.to_proto(self._session.client)
+        return self._session.client._analyze(query).input_files
 
     def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame":
         """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations.
@@ -1011,10 +1011,10 @@ class DataFrame(object):
             explain_mode = cast(str, extended)
 
         if self._plan is not None:
-            query = self._plan.to_proto(self._session)
+            query = self._plan.to_proto(self._session.client)
             if self._session is None:
-                raise Exception("Cannot analyze without RemoteSparkSession.")
-            return self._session.explain_string(query, explain_mode)
+                raise Exception("Cannot analyze without SparkSession.")
+            return self._session.client.explain_string(query, explain_mode)
         else:
             return ""
 
@@ -1032,8 +1032,8 @@ class DataFrame(object):
         """
         command = plan.CreateView(
             child=self._plan, name=name, is_global=True, replace=False
-        ).command(session=self._session)
-        self._session.execute_command(command)
+        ).command(session=self._session.client)
+        self._session.client.execute_command(command)
 
     def createOrReplaceGlobalTempView(self, name: str) -> None:
         """Creates or replaces a global temporary view using the given name.
@@ -1049,8 +1049,8 @@ class DataFrame(object):
         """
         command = plan.CreateView(
             child=self._plan, name=name, is_global=True, replace=True
-        ).command(session=self._session)
-        self._session.execute_command(command)
+        ).command(session=self._session.client)
+        self._session.client.execute_command(command)
 
     def rdd(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("RDD Support for Spark Connect is not implemented.")
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
index 4a2688d6a0d..8df5e56b452 100644
--- a/python/pyspark/sql/connect/function_builder.py
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
         FunctionBuilderCallable,
         UserDefinedFunctionCallable,
     )
-    from pyspark.sql.connect.client import RemoteSparkSession
+    from pyspark.sql.connect.client import SparkConnectClient
 
 
 def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression:
@@ -91,7 +91,7 @@ class UserDefinedFunction(Expression):
             self._args = []
         self._func_name = None
 
-    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         if session is None:
             raise Exception("CAnnot create UDF without remote Session.")
         # Needs to materialize the UDF to the server
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 853b1a6dc0e..9a22d6ea38e 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.column import (
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString
-    from pyspark.sql.connect.client import RemoteSparkSession
+    from pyspark.sql.connect.client import SparkConnectClient
 
 
 class InputValidationError(Exception):
@@ -57,7 +57,7 @@ class LogicalPlan(object):
         return exp
 
     def to_attr_or_expression(
-        self, col: "ColumnOrName", session: "RemoteSparkSession"
+        self, col: "ColumnOrName", session: "SparkConnectClient"
     ) -> proto.Expression:
         """Returns either an instance of an unresolved attribute or the serialized
         expression value of the column."""
@@ -66,13 +66,13 @@ class LogicalPlan(object):
         else:
             return cast(Column, col).to_plan(session)
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         ...
 
-    def command(self, session: "RemoteSparkSession") -> proto.Command:
+    def command(self, session: "SparkConnectClient") -> proto.Command:
         ...
 
-    def _verify(self, session: "RemoteSparkSession") -> bool:
+    def _verify(self, session: "SparkConnectClient") -> bool:
         """This method is used to verify that the current logical plan
         can be serialized to Proto and back and afterwards is identical."""
         plan = proto.Plan()
@@ -84,13 +84,13 @@ class LogicalPlan(object):
 
         return test_plan == plan
 
-    def to_proto(self, session: "RemoteSparkSession", debug: bool = False) -> proto.Plan:
+    def to_proto(self, session: "SparkConnectClient", debug: bool = False) -> proto.Plan:
         """
         Generates connect proto plan based on this LogicalPlan.
 
         Parameters
         ----------
-        session : :class:`RemoteSparkSession`, optional.
+        session : :class:`SparkConnectClient`, optional.
             a session that connects remote spark cluster.
         debug: bool
             if enabled, the proto plan will be printed.
@@ -127,7 +127,7 @@ class DataSource(LogicalPlan):
         self.schema = schema
         self.options = options
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         plan = proto.Relation()
         if self.format is not None:
             plan.read.data_source.format = self.format
@@ -158,7 +158,7 @@ class Read(LogicalPlan):
         super().__init__(None)
         self.table_name = table_name
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         plan = proto.Relation()
         plan.read.named_table.unparsed_identifier = self.table_name
         return plan
@@ -186,7 +186,7 @@ class ShowString(LogicalPlan):
         self.truncate = truncate
         self.vertical = vertical
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.show_string.input.CopyFrom(self._child.plan(session))
@@ -242,7 +242,7 @@ class Project(LogicalPlan):
                     f"Only Expressions or String can be used for projections: '{c}'."
                 )
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         proj_exprs = []
         for c in self._raw_columns:
@@ -281,7 +281,7 @@ class Filter(LogicalPlan):
         super().__init__(child)
         self.filter = filter
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.filter.input.CopyFrom(self._child.plan(session))
@@ -309,7 +309,7 @@ class Limit(LogicalPlan):
         super().__init__(child)
         self.limit = limit
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.limit.input.CopyFrom(self._child.plan(session))
@@ -337,7 +337,7 @@ class Offset(LogicalPlan):
         super().__init__(child)
         self.offset = offset
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.offset.input.CopyFrom(self._child.plan(session))
@@ -371,7 +371,7 @@ class Deduplicate(LogicalPlan):
         self.all_columns_as_keys = all_columns_as_keys
         self.column_names = column_names
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
@@ -411,7 +411,7 @@ class Sort(LogicalPlan):
         self.is_global = is_global
 
     def col_to_sort_field(
-        self, col: Union[SortOrder, Column, str], session: "RemoteSparkSession"
+        self, col: Union[SortOrder, Column, str], session: "SparkConnectClient"
     ) -> proto.Sort.SortField:
         if isinstance(col, SortOrder):
             sf = proto.Sort.SortField()
@@ -438,7 +438,7 @@ class Sort(LogicalPlan):
             sf.nulls = proto.Sort.SortNulls.SORT_NULLS_LAST
             return sf
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.sort.input.CopyFrom(self._child.plan(session))
@@ -474,7 +474,7 @@ class Drop(LogicalPlan):
         self.columns = columns
 
     def _convert_to_expr(
-        self, col: Union[Column, str], session: "RemoteSparkSession"
+        self, col: Union[Column, str], session: "SparkConnectClient"
     ) -> proto.Expression:
         expr = proto.Expression()
         if isinstance(col, Column):
@@ -483,7 +483,7 @@ class Drop(LogicalPlan):
             expr.CopyFrom(self.unresolved_attr(col))
         return expr
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.drop.input.CopyFrom(self._child.plan(session))
@@ -521,7 +521,7 @@ class Sample(LogicalPlan):
         self.with_replacement = with_replacement
         self.seed = seed
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.sample.input.CopyFrom(self._child.plan(session))
@@ -567,12 +567,12 @@ class Aggregate(LogicalPlan):
         self.grouping_cols = grouping_cols
         self.measures = measures
 
-    def _convert_measure(self, m: Expression, session: "RemoteSparkSession") -> proto.Expression:
+    def _convert_measure(self, m: Expression, session: "SparkConnectClient") -> proto.Expression:
         proto_expr = proto.Expression()
         proto_expr.CopyFrom(m.to_plan(session))
         return proto_expr
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         groupings = [x.to_plan(session) for x in self.grouping_cols]
 
@@ -642,7 +642,7 @@ class Join(LogicalPlan):
             )
         self.how = join_type
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         rel = proto.Relation()
         rel.join.left.CopyFrom(self.left.plan(session))
         rel.join.right.CopyFrom(self.right.plan(session))
@@ -693,7 +693,7 @@ class SetOperation(LogicalPlan):
         self.is_all = is_all
         self.set_op = set_op
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         rel = proto.Relation()
         if self._child is not None:
@@ -753,7 +753,7 @@ class Repartition(LogicalPlan):
         self._num_partitions = num_partitions
         self._shuffle = shuffle
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         rel = proto.Relation()
         if self._child is not None:
             rel.repartition.input.CopyFrom(self._child.plan(session))
@@ -786,7 +786,7 @@ class SubqueryAlias(LogicalPlan):
         super().__init__(child)
         self._alias = alias
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         rel = proto.Relation()
         if self._child is not None:
             rel.subquery_alias.input.CopyFrom(self._child.plan(session))
@@ -814,7 +814,7 @@ class SQL(LogicalPlan):
         super().__init__(None)
         self._query = query
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         rel = proto.Relation()
         rel.sql.query = self._query
         return rel
@@ -849,7 +849,7 @@ class Range(LogicalPlan):
         self._step = step
         self._num_partitions = num_partitions
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         rel = proto.Relation()
         rel.range.start = self._start
         rel.range.end = self._end
@@ -912,7 +912,7 @@ class NAFill(LogicalPlan):
             value.string = v
         return value
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.fill_na.input.CopyFrom(self._child.plan(session))
@@ -942,7 +942,7 @@ class StatSummary(LogicalPlan):
         super().__init__(child)
         self.statistics = statistics
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.summary.input.CopyFrom(self._child.plan(session))
@@ -971,7 +971,7 @@ class StatCrosstab(LogicalPlan):
         self.col1 = col1
         self.col2 = col2
 
-    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
 
         plan = proto.Relation()
@@ -1006,7 +1006,7 @@ class CreateView(LogicalPlan):
         self._is_gloal = is_global
         self._replace = replace
 
-    def command(self, session: "RemoteSparkSession") -> proto.Command:
+    def command(self, session: "SparkConnectClient") -> proto.Command:
         assert self._child is not None
 
         plan = proto.Command()
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 27aa023ae47..ead027c206b 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -26,7 +26,7 @@ from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
-    from pyspark.sql.connect.client import RemoteSparkSession
+    from pyspark.sql.connect.session import SparkSession
 
 
 class DataFrameReader:
@@ -34,7 +34,7 @@ class DataFrameReader:
     TODO(SPARK-40539) Achieve parity with PySpark.
     """
 
-    def __init__(self, client: "RemoteSparkSession"):
+    def __init__(self, client: "SparkSession"):
         self._client = client
         self._format = ""
         self._schema = ""
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
new file mode 100644
index 00000000000..92f58140eac
--- /dev/null
+++ b/python/pyspark/sql/connect/session.py
@@ -0,0 +1,258 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from threading import RLock
+from typing import Optional, Any, Union, Dict, cast, overload
+
+import pyspark.sql.types
+from pyspark.sql.connect.client import SparkConnectClient
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.connect.plan import SQL, Range
+from pyspark.sql.connect.readwriter import DataFrameReader
+from pyspark.sql.utils import to_str
+from ._typing import OptionalPrimitiveType
+
+
+# TODO(SPARK-38912): This class can be dropped once support for Python 3.8 is dropped
+# In Python 3.9, the @property decorator has been made compatible with the
+# @classmethod decorator (https://docs.python.org/3.9/library/functions.html#classmethod)
+#
+# @classmethod + @property is also affected by a bug in Python's docstring which was backported
+# to Python 3.9.6 (https://github.com/python/cpython/pull/28838)
+class classproperty(property):
+    """Like Python's @property decorator, but the getter runs on the class at every access.
+
+    Examples
+    --------
+    >>> class Builder:
+    ...    def build(self):
+    ...        return MyClass()
+    ...
+    >>> class MyClass:
+    ...     @classproperty
+    ...     def builder(cls):
+    ...         print("instantiating new builder")
+    ...         return Builder()
+    ...
+    >>> c1 = MyClass.builder
+    instantiating new builder
+    >>> c2 = MyClass.builder
+    instantiating new builder
+    >>> c1 == c2
+    False
+    >>> isinstance(c1.build(), MyClass)
+    True
+    """
+
+    def __get__(self, instance: Any, owner: Any = None) -> "SparkSession.Builder":
+        # Wrap fget in a classmethod, bind it to ``owner`` and call it, so the getter
+        # receives the class (never the instance). The "type: ignore" silences mypy:
+        # Argument 1 to "classmethod" has incompatible type
+        # "Optional[Callable[[Any], Any]]"; expected "Callable[..., Any]"  [arg-type]
+        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
+
+
+class SparkSession(object):
+    """Conceptually the remote spark session that communicates with the server"""
+
+    class Builder:
+        """Builder for :class:`SparkSession`."""
+
+        _lock = RLock()
+
+        def __init__(self) -> None:
+            self._options: Dict[str, Any] = {}
+
+        @overload
+        def config(self, key: str, value: Any) -> "SparkSession.Builder":
+            ...
+
+        @overload
+        def config(self, *, map: Dict[str, "OptionalPrimitiveType"]) -> "SparkSession.Builder":
+            ...
+
+        def config(
+            self,
+            key: Optional[str] = None,
+            value: Optional[Any] = None,
+            *,
+            map: Optional[Dict[str, "OptionalPrimitiveType"]] = None,
+        ) -> "SparkSession.Builder":
+            """Sets a config option. Options are stored on this builder; currently only
+            ``spark.connect.location`` is read by :meth:`getOrCreate`.
+
+            .. versionadded:: 2.0.0
+
+            Parameters
+            ----------
+            key : str, optional
+                a key name string for configuration property
+            value : str, optional
+                a value for configuration property
+            map : dict, optional
+                a dictionary of configurations to set
+
+                .. versionadded:: 3.4.0
+
+            Returns
+            -------
+            :class:`SparkSession.Builder`
+
+            Examples
+            --------
+            For a (key, value) pair, you can omit parameter names.
+
+            >>> SparkSession.builder.config("spark.some.config.option", "some-value")
+            <pyspark.sql.connect.session.SparkSession.Builder...
+
+            Additionally, you can pass a dictionary of configurations to set.
+
+            >>> SparkSession.builder.config(
+            ...     map={"spark.some.config.number": 123, "spark.some.config.float": 0.123})
+            <pyspark.sql.connect.session.SparkSession.Builder...
+            """
+            with self._lock:
+                if map is not None:
+                    for k, v in map.items():
+                        self._options[k] = to_str(v)
+                else:
+                    self._options[cast(str, key)] = to_str(value)
+                return self
+
+        def master(self, master: str) -> "SparkSession.Builder":
+            return self  # no-op: the master URL is ignored by this builder
+
+        def appName(self, name: str) -> "SparkSession.Builder":
+            """Sets a name for the application (stored as ``spark.app.name``).
+
+            The name is currently only recorded on this builder; it is not sent to the server.
+
+            .. versionadded:: 2.0.0
+
+            Parameters
+            ----------
+            name : str
+                an application name
+
+            Returns
+            -------
+            :class:`SparkSession.Builder`
+
+            Examples
+            --------
+            >>> SparkSession.builder.appName("My app")
+            <pyspark.sql.connect.session.SparkSession.Builder...
+            """
+            return self.config("spark.app.name", name)
+
+        def remote(self, location: str = "sc://localhost") -> "SparkSession.Builder":
+            return self.config("spark.connect.location", location)
+
+        def enableHiveSupport(self) -> "SparkSession.Builder":
+            raise NotImplementedError("enableHiveSupport not  implemented for Spark Connect")
+
+        def getOrCreate(self) -> "SparkSession":
+            """Creates a new instance; the location defaults to ``sc://localhost``."""
+            return SparkSession(self._options.get("spark.connect.location", "sc://localhost"))
+
+    _client: SparkConnectClient
+
+    # TODO(SPARK-38912): Replace @classproperty with @classmethod + @property once support for
+    # Python 3.8 is dropped.
+    #
+    # In Python 3.9, the @property decorator has been made compatible with the
+    # @classmethod decorator (https://docs.python.org/3.9/library/functions.html#classmethod)
+    #
+    # @classmethod + @property is also affected by a bug in Python's docstring which was backported
+    # to Python 3.9.6 (https://github.com/python/cpython/pull/28838)
+    @classproperty
+    def builder(cls) -> Builder:
+        """Creates a :class:`Builder` for constructing a :class:`SparkSession`."""
+        return cls.Builder()
+
+    def __init__(self, connectionString: str, userId: Optional[str] = None):
+        """
+        Creates a new SparkSession for the Spark Connect interface.
+
+        Parameters
+        ----------
+        connectionString : str
+            Connection string that is used to extract the connection parameters and configure
+            the GRPC connection, e.g. ``sc://localhost``.
+        userId : Optional[str]
+            Optional unique user ID meant to differentiate multiple users and isolate their
+            Spark Sessions. NOTE(review): this argument is currently NOT forwarded to
+            :class:`SparkConnectClient` (only ``connectionString`` is); a user ID embedded
+            in the connection string is presumably still honored by the client.
+        """
+        # SparkConnectClient parses the connection string and sets up the GRPC connection.
+        self._client = SparkConnectClient(connectionString)
+
+        # Create the reader
+        self.read = DataFrameReader(self)
+
+    @property
+    def client(self) -> "SparkConnectClient":
+        """
+        Gives access to the Spark Connect client. In normal cases this is not necessary to be used
+        and only relevant for testing.
+        Returns
+        -------
+        :class:`SparkConnectClient`
+        """
+        return self._client
+
+    def register_udf(
+        self, function: Any, return_type: Union[str, pyspark.sql.types.DataType]
+    ) -> str:
+        return self._client.register_udf(function, return_type)
+
+    def sql(self, sql_string: str) -> "DataFrame":
+        return DataFrame.withPlan(SQL(sql_string), self)
+
+    def range(
+        self,
+        start: int,
+        end: int,
+        step: int = 1,
+        numPartitions: Optional[int] = None,
+    ) -> DataFrame:
+        """
+        Create a :class:`DataFrame` with column named ``id`` and typed Long,
+        containing elements in a range from ``start`` to ``end`` (exclusive) with
+        step value ``step``.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        start : int
+            the start value
+        end : int
+            the end value (exclusive)
+        step : int, optional
+            the incremental step (default: 1)
+        numPartitions : int, optional
+            the number of partitions of the DataFrame
+
+        Returns
+        -------
+        :class:`DataFrame`
+        """
+        return DataFrame.withPlan(
+            Range(start=start, end=end, step=step, num_partitions=numPartitions), self
+        )
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 845d6ead567..150bbdb65ef 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -30,7 +30,8 @@ from pyspark.sql import SparkSession, Row
 from pyspark.sql.types import StructType, StructField, LongType, StringType
 
 if have_pandas:
-    from pyspark.sql.connect.client import RemoteSparkSession, ChannelBuilder
+    from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
+    from pyspark.sql.connect.client import ChannelBuilder
     from pyspark.sql.connect.function_builder import udf
     from pyspark.sql.connect.functions import lit, col
 from pyspark.sql.dataframe import DataFrame
@@ -79,7 +80,7 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLT
     @classmethod
     def spark_connect_load_test_data(cls: Any):
         # Setup Remote Spark Session
-        cls.connect = RemoteSparkSession(userId="test_user")
+        cls.connect = RemoteSparkSession.builder.remote().getOrCreate()
         df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
         # Since we might create multiple Spark sessions, we need to create global temporary view
         # that is specifically maintained in the "global_temp" schema.
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 99b63482a24..03966bd28df 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -140,7 +140,7 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
         self.assertEqual("Alias(Column(a), (martin))", str(col0))
 
         col0 = fun.col("a").alias("martin", metadata={"pii": True})
-        plan = col0.to_plan(self.session)
+        plan = col0.to_plan(self.session.client)
         self.assertIsNotNone(plan)
         self.assertEqual(plan.alias.metadata, '{"pii": true}')
 
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index f98a67b9964..feca9e9f825 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -26,7 +26,7 @@ if have_pandas:
     from pyspark.sql.connect.plan import Read, Range, SQL
     from pyspark.testing.utils import search_jar
     from pyspark.sql.connect.plan import LogicalPlan
-    from pyspark.sql.connect.client import RemoteSparkSession
+    from pyspark.sql.connect.session import SparkSession
 
     connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
 else:
@@ -69,7 +69,7 @@ class MockRemoteSession:
 class PlanOnlyTestFixture(unittest.TestCase):
 
     connect: "MockRemoteSession"
-    session: RemoteSparkSession
+    session: SparkSession
 
     @classmethod
     def _read_table(cls, table_name: str) -> "DataFrame":
@@ -102,7 +102,7 @@ class PlanOnlyTestFixture(unittest.TestCase):
     @classmethod
     def setUpClass(cls: Any) -> None:
         cls.connect = MockRemoteSession()
-        cls.session = RemoteSparkSession()
+        cls.session = SparkSession.builder.remote().getOrCreate()
         cls.tbl_name = "test_connect_plan_only_table_1"
 
         cls.connect.set_hook("register_udf", cls._udf_mock)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org