You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/28 05:36:00 UTC

[spark] branch branch-3.4 updated: [SPARK-42908][PYTHON] Raise RuntimeError when SparkContext is required but not initialized

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 0c4ad508476 [SPARK-42908][PYTHON] Raise RuntimeError when SparkContext is required but not initialized
0c4ad508476 is described below

commit 0c4ad508476b54d7d3acd303ff686310dd198a3d
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Tue Mar 28 14:35:38 2023 +0900

    [SPARK-42908][PYTHON] Raise RuntimeError when SparkContext is required but not initialized
    
    ### What changes were proposed in this pull request?
    Raise RuntimeError when SparkContext is required but not initialized.
    
    ### Why are the changes needed?
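    To improve the error: a missing SparkContext previously surfaced as a bare AssertionError with no explanatory message.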
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the error type and the error message change.
    
    Raise a RuntimeError with a clear message (rather than an AssertionError) when SparkContext is required but not initialized yet.
    
    ### How was this patch tested?
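    Added a unit test (`test_err_parse_type_when_no_sc` in `python/pyspark/sql/tests/test_udf.py`) that checks the RuntimeError and its message.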
    
    Closes #40534 from xinrong-meng/err_msg.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 70f6206dbcd3c5ff0f4618cf179b7fcf75ae672c)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/avro/functions.py     | 21 ++++++-----
 python/pyspark/sql/column.py             | 32 +++++++---------
 python/pyspark/sql/dataframe.py          | 10 +++--
 python/pyspark/sql/functions.py          | 65 +++++++++++++++-----------------
 python/pyspark/sql/protobuf/functions.py | 21 ++++++-----
 python/pyspark/sql/tests/test_udf.py     |  7 ++++
 python/pyspark/sql/types.py              | 17 ++++-----
 python/pyspark/sql/udf.py                |  4 +-
 python/pyspark/sql/utils.py              |  9 +++++
 python/pyspark/sql/window.py             | 40 +++++++++++---------
 10 files changed, 120 insertions(+), 106 deletions(-)

diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index cf6676c8ab1..080e45934e6 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -20,9 +20,12 @@ A collections of builtin avro functions
 """
 
 
-from typing import Dict, Optional, TYPE_CHECKING
-from pyspark import SparkContext
+from typing import Dict, Optional, TYPE_CHECKING, cast
+
+from py4j.java_gateway import JVMView
+
 from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.utils import get_active_spark_context
 from pyspark.util import _print_missing_jar
 
 if TYPE_CHECKING:
@@ -73,10 +76,9 @@ def from_avro(
     [Row(value=Row(avro=Row(age=2, name='Alice')))]
     """
 
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     try:
-        jc = sc._jvm.org.apache.spark.sql.avro.functions.from_avro(
+        jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.from_avro(
             _to_java_column(data), jsonFormatSchema, options or {}
         )
     except TypeError as e:
@@ -119,13 +121,14 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
     [Row(suite=bytearray(b'\\x02\\x00'))]
     """
 
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     try:
         if jsonFormatSchema == "":
-            jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(_to_java_column(data))
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.to_avro(
+                _to_java_column(data)
+            )
         else:
-            jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.to_avro(
                 _to_java_column(data), jsonFormatSchema
             )
     except TypeError as e:
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index d13d3954bca..0a18930b8eb 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -31,12 +31,13 @@ from typing import (
     Union,
 )
 
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JVMView
 
 from pyspark import copy_func
 from pyspark.context import SparkContext
 from pyspark.errors import PySparkTypeError
 from pyspark.sql.types import DataType
+from pyspark.sql.utils import get_active_spark_context
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral
@@ -46,15 +47,13 @@ __all__ = ["Column"]
 
 
 def _create_column_from_literal(literal: Union["LiteralType", "DecimalLiteral"]) -> "Column":
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    return sc._jvm.functions.lit(literal)
+    sc = get_active_spark_context()
+    return cast(JVMView, sc._jvm).functions.lit(literal)
 
 
 def _create_column_from_name(name: str) -> "Column":
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    return sc._jvm.functions.col(name)
+    sc = get_active_spark_context()
+    return cast(JVMView, sc._jvm).functions.col(name)
 
 
 def _to_java_column(col: "ColumnOrName") -> JavaObject:
@@ -122,9 +121,8 @@ def _unary_op(
 
 def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
     def _(self: "Column") -> "Column":
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jc = getattr(sc._jvm.functions, name)(self._jc)
+        sc = get_active_spark_context()
+        jc = getattr(cast(JVMView, sc._jvm).functions, name)(self._jc)
         return Column(jc)
 
     _.__doc__ = doc
@@ -137,9 +135,8 @@ def _bin_func_op(
     doc: str = "binary function",
 ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"]:
     def _(self: "Column", other: Union["Column", "LiteralType", "DecimalLiteral"]) -> "Column":
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        fn = getattr(sc._jvm.functions, name)
+        sc = get_active_spark_context()
+        fn = getattr(cast(JVMView, sc._jvm).functions, name)
         jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
         njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
         return Column(njc)
@@ -633,8 +630,7 @@ class Column:
         +--------------+
 
         """
-        sc = SparkContext._active_spark_context
-        assert sc is not None
+        sc = get_active_spark_context()
         jc = self._jc.dropFields(_to_seq(sc, fieldNames))
         return Column(jc)
 
@@ -962,8 +958,7 @@ class Column:
             Tuple,
             [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols],
         )
-        sc = SparkContext._active_spark_context
-        assert sc is not None
+        sc = get_active_spark_context()
         jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
         return Column(jc)
 
@@ -1144,8 +1139,7 @@ class Column:
         metadata = kwargs.pop("metadata", None)
         assert not kwargs, "Unexpected kwargs where passed: %s" % kwargs
 
-        sc = SparkContext._active_spark_context
-        assert sc is not None
+        sc = get_active_spark_context()
         if len(alias) == 1:
             if metadata:
                 assert sc._jvm is not None
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cc5d264bd34..518bc9867d7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -39,7 +39,7 @@ from typing import (
     TYPE_CHECKING,
 )
 
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JVMView
 
 from pyspark import copy_func, _NoValue
 from pyspark._globals import _NoValueType
@@ -61,6 +61,7 @@ from pyspark.sql.types import (
     Row,
     _parse_datatype_json_string,
 )
+from pyspark.sql.utils import get_active_spark_context
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 
@@ -4899,9 +4900,10 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 error_class="NOT_DICT",
                 message_parameters={"arg_name": "metadata", "arg_type": type(metadata).__name__},
             )
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(json.dumps(metadata))
+        sc = get_active_spark_context()
+        jmeta = cast(JVMView, sc._jvm).org.apache.spark.sql.types.Metadata.fromJson(
+            json.dumps(metadata)
+        )
         return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)
 
     @overload
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bb5a1a559be..ab099554293 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -37,6 +37,8 @@ from typing import (
     ValuesView,
 )
 
+from py4j.java_gateway import JVMView
+
 from pyspark import SparkContext
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.rdd import PythonEvalType
@@ -49,7 +51,12 @@ from pyspark.sql.udf import UserDefinedFunction, _create_py_udf  # noqa: F401
 
 # Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType  # noqa: F401
-from pyspark.sql.utils import to_str, has_numpy, try_remote_functions
+from pyspark.sql.utils import (
+    to_str,
+    has_numpy,
+    try_remote_functions,
+    get_active_spark_context,
+)
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import (
@@ -101,8 +108,7 @@ def _invoke_function_over_seq_of_columns(name: str, cols: "Iterable[ColumnOrName
     Invokes unary JVM function identified by name with
     and wraps the result with :class:`~pyspark.sql.Column`.
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     return _invoke_function(name, _to_seq(sc, cols, _to_java_column))
 
 
@@ -2676,9 +2682,8 @@ def broadcast(df: DataFrame) -> DataFrame:
     +-----+---+
     """
 
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sparkSession)
+    sc = get_active_spark_context()
+    return DataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession)
 
 
 @try_remote_functions
@@ -2891,8 +2896,7 @@ def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column:
     |                           4|
     +----------------------------+
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     return _invoke_function(
         "count_distinct", _to_java_column(col), _to_seq(sc, cols, _to_java_column)
     )
@@ -3304,8 +3308,7 @@ def percentile_approx(
      |-- key: long (nullable = true)
      |-- median: double (nullable = true)
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
 
     if isinstance(percentage, (list, tuple)):
         # A local list
@@ -6226,8 +6229,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> Column:
     >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
     [Row(s='abcd-123')]
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     return _invoke_function("concat_ws", sep, _to_seq(sc, cols, _to_java_column))
 
 
@@ -6360,8 +6362,7 @@ def format_string(format: str, *cols: "ColumnOrName") -> Column:
     >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
     [Row(v='5 hello')]
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     return _invoke_function("format_string", format, _to_seq(sc, cols, _to_java_column))
 
 
@@ -7419,8 +7420,7 @@ def array_join(
     >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
     [Row(joined='a,b,c'), Row(joined='a,NULL')]
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    get_active_spark_context()
     if null_replacement is None:
         return _invoke_function("array_join", _to_java_column(col), delimiter)
     else:
@@ -8229,8 +8229,7 @@ def json_tuple(col: "ColumnOrName", *fields: str) -> Column:
     >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect()
     [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)]
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     return _invoke_function("json_tuple", _to_java_column(col), _to_seq(sc, fields))
 
 
@@ -9182,8 +9181,7 @@ def from_csv(
     [Row(csv=Row(s='abc'))]
     """
 
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    get_active_spark_context()
     if isinstance(schema, str):
         schema = _create_column_from_literal(schema)
     elif isinstance(schema, Column):
@@ -9209,11 +9207,12 @@ def _unresolved_named_lambda_variable(*name_parts: Any) -> Column:
     ----------
     name_parts : str
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     name_parts_seq = _to_seq(sc, name_parts)
-    expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
-    return Column(sc._jvm.Column(expressions.UnresolvedNamedLambdaVariable(name_parts_seq)))
+    expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions
+    return Column(
+        cast(JVMView, sc._jvm).Column(expressions.UnresolvedNamedLambdaVariable(name_parts_seq))
+    )
 
 
 def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
@@ -9258,9 +9257,8 @@ def _create_lambda(f: Callable) -> Callable:
     """
     parameters = _get_lambda_parameters(f)
 
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
+    sc = get_active_spark_context()
+    expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions
 
     argnames = ["x", "y", "z"]
     args = [
@@ -9300,15 +9298,14 @@ def _invoke_higher_order_function(
 
     :return: a Column
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
+    sc = get_active_spark_context()
+    expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions
     expr = getattr(expressions, name)
 
     jcols = [_to_java_column(col).expr() for col in cols]
     jfuns = [_create_lambda(f) for f in funs]
 
-    return Column(sc._jvm.Column(expr(*jcols + jfuns)))
+    return Column(cast(JVMView, sc._jvm).Column(expr(*jcols + jfuns)))
 
 
 @overload
@@ -10017,8 +10014,7 @@ def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column:
             message_parameters={"arg_name": "numBuckets", "arg_type": type(numBuckets).__name__},
         )
 
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    get_active_spark_context()
     numBuckets = (
         _create_column_from_literal(numBuckets)
         if isinstance(numBuckets, int)
@@ -10070,8 +10066,7 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
     |         cc|
     +-----------+
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     return _invoke_function("call_udf", udfName, _to_seq(sc, cols, _to_java_column))
 
 
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 1fed9cfda66..a303cf91493 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -20,9 +20,12 @@ A collections of builtin protobuf functions
 """
 
 
-from typing import Dict, Optional, TYPE_CHECKING
-from pyspark import SparkContext
+from typing import Dict, Optional, TYPE_CHECKING, cast
+
+from py4j.java_gateway import JVMView
+
 from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.utils import get_active_spark_context
 from pyspark.util import _print_missing_jar
 
 if TYPE_CHECKING:
@@ -117,15 +120,14 @@ def from_protobuf(
     +------------------+
     """
 
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     try:
         if descFilePath is not None:
-            jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.from_protobuf(
                 _to_java_column(data), messageName, descFilePath, options or {}
             )
         else:
-            jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.from_protobuf(
                 _to_java_column(data), messageName, options or {}
             )
     except TypeError as e:
@@ -212,15 +214,14 @@ def to_protobuf(
     +----------------------------+
     """
 
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     try:
         if descFilePath is not None:
-            jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.to_protobuf(
                 _to_java_column(data), messageName, descFilePath, options or {}
             )
         else:
-            jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.to_protobuf(
                 _to_java_column(data), messageName, options or {}
             )
 
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 1b414baeec3..d8a464b006f 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -858,6 +858,13 @@ class UDFInitializationTests(unittest.TestCase):
             "SparkSession shouldn't be initialized when UserDefinedFunction is created.",
         )
 
+    def test_err_parse_type_when_no_sc(self):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "SparkContext or SparkSession should be created first",
+        ):
+            udf(lambda x: x, "integer")
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_udf import *  # noqa: F401
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9cb17e85540..ff43e4b00e9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -46,10 +46,10 @@ from typing import (
 )
 
 from py4j.protocol import register_input_converter
-from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject
+from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject, JVMView
 
 from pyspark.serializers import CloudPickleSerializer
-from pyspark.sql.utils import has_numpy
+from pyspark.sql.utils import has_numpy, get_active_spark_context
 
 if has_numpy:
     import numpy as np
@@ -1208,21 +1208,18 @@ def _parse_datatype_string(s: str) -> DataType:
         ...
     ParseException:...
     """
-    from pyspark import SparkContext
-
-    sc = SparkContext._active_spark_context
-    assert sc is not None
+    sc = get_active_spark_context()
 
     def from_ddl_schema(type_str: str) -> DataType:
-        assert sc is not None and sc._jvm is not None
         return _parse_datatype_json_string(
-            sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json()
+            cast(JVMView, sc._jvm).org.apache.spark.sql.types.StructType.fromDDL(type_str).json()
         )
 
     def from_ddl_datatype(type_str: str) -> DataType:
-        assert sc is not None and sc._jvm is not None
         return _parse_datatype_json_string(
-            sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json()
+            cast(JVMView, sc._jvm)
+            .org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str)
+            .json()
         )
 
     try:
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index e79d04141ae..c1fa3d187fe 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -40,6 +40,7 @@ from pyspark.sql.types import (
     StructType,
     _parse_datatype_string,
 )
+from pyspark.sql.utils import get_active_spark_context
 from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
 
@@ -334,8 +335,7 @@ class UserDefinedFunction:
         return judf
 
     def __call__(self, *cols: "ColumnOrName") -> Column:
-        sc = SparkContext._active_spark_context
-        assert sc is not None
+        sc = get_active_spark_context()
         profiler: Optional[Profiler] = None
         memory_profiler: Optional[Profiler] = None
         if sc.profiler_collector:
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b9b045541a6..b5d17e38b87 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -193,6 +193,15 @@ def try_remote_windowspec(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+def get_active_spark_context() -> SparkContext:
+    """Raise RuntimeError if SparkContext is not initialized,
+    otherwise, returns the active SparkContext."""
+    sc = SparkContext._active_spark_context
+    if sc is None or sc._jvm is None:
+        raise RuntimeError("SparkContext or SparkSession should be created first.")
+    return sc
+
+
 def try_remote_observation(f: FuncT) -> FuncT:
     """Mark API supported from Spark Connect."""
 
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 92b251ba63f..ca05cb0cc7f 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -17,11 +17,14 @@
 import sys
 from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
 
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JVMView
 
-from pyspark import SparkContext
 from pyspark.sql.column import _to_seq, _to_java_column
-from pyspark.sql.utils import try_remote_window, try_remote_windowspec
+from pyspark.sql.utils import (
+    try_remote_window,
+    try_remote_windowspec,
+    get_active_spark_context,
+)
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName, ColumnOrName_
@@ -30,10 +33,9 @@ __all__ = ["Window", "WindowSpec"]
 
 
 def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> JavaObject:
-    sc = SparkContext._active_spark_context
     if len(cols) == 1 and isinstance(cols[0], list):
         cols = cols[0]  # type: ignore[assignment]
-    assert sc is not None
+    sc = get_active_spark_context()
     return _to_seq(sc, cast(Iterable["ColumnOrName"], cols), _to_java_column)
 
 
@@ -123,9 +125,10 @@ class Window:
         |  3|       b|         3|
         +---+--------+----------+
         """
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols))
+        sc = get_active_spark_context()
+        jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.partitionBy(
+            _to_java_cols(cols)
+        )
         return WindowSpec(jspec)
 
     @staticmethod
@@ -179,9 +182,10 @@ class Window:
         |  3|       b|         1|
         +---+--------+----------+
         """
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
+        sc = get_active_spark_context()
+        jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.orderBy(
+            _to_java_cols(cols)
+        )
         return WindowSpec(jspec)
 
     @staticmethod
@@ -263,9 +267,10 @@ class Window:
             start = Window.unboundedPreceding
         if end >= Window._FOLLOWING_THRESHOLD:
             end = Window.unboundedFollowing
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end)
+        sc = get_active_spark_context()
+        jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.rowsBetween(
+            start, end
+        )
         return WindowSpec(jspec)
 
     @staticmethod
@@ -350,9 +355,10 @@ class Window:
             start = Window.unboundedPreceding
         if end >= Window._FOLLOWING_THRESHOLD:
             end = Window.unboundedFollowing
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end)
+        sc = get_active_spark_context()
+        jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.rangeBetween(
+            start, end
+        )
         return WindowSpec(jspec)
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org