You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/28 05:36:00 UTC
[spark] branch branch-3.4 updated: [SPARK-42908][PYTHON] Raise RuntimeError when SparkContext is required but not initialized
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 0c4ad508476 [SPARK-42908][PYTHON] Raise RuntimeError when SparkContext is required but not initialized
0c4ad508476 is described below
commit 0c4ad508476b54d7d3acd303ff686310dd198a3d
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Tue Mar 28 14:35:38 2023 +0900
[SPARK-42908][PYTHON] Raise RuntimeError when SparkContext is required but not initialized
### What changes were proposed in this pull request?
Raise RuntimeError when SparkContext is required but not initialized.
### Why are the changes needed?
Error-message improvement: raising a RuntimeError with a clear message (instead of a bare AssertionError) tells users explicitly that a SparkContext or SparkSession must be created first.
### Does this PR introduce _any_ user-facing change?
Error type and message change.
Raise a RuntimeError with a clear message (rather than an AssertionError) when SparkContext is required but not initialized yet.
### How was this patch tested?
Unit test.
Closes #40534 from xinrong-meng/err_msg.
Authored-by: Xinrong Meng <xi...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 70f6206dbcd3c5ff0f4618cf179b7fcf75ae672c)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/avro/functions.py | 21 ++++++-----
python/pyspark/sql/column.py | 32 +++++++---------
python/pyspark/sql/dataframe.py | 10 +++--
python/pyspark/sql/functions.py | 65 +++++++++++++++-----------------
python/pyspark/sql/protobuf/functions.py | 21 ++++++-----
python/pyspark/sql/tests/test_udf.py | 7 ++++
python/pyspark/sql/types.py | 17 ++++-----
python/pyspark/sql/udf.py | 4 +-
python/pyspark/sql/utils.py | 9 +++++
python/pyspark/sql/window.py | 40 +++++++++++---------
10 files changed, 120 insertions(+), 106 deletions(-)
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index cf6676c8ab1..080e45934e6 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -20,9 +20,12 @@ A collections of builtin avro functions
"""
-from typing import Dict, Optional, TYPE_CHECKING
-from pyspark import SparkContext
+from typing import Dict, Optional, TYPE_CHECKING, cast
+
+from py4j.java_gateway import JVMView
+
from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.utils import get_active_spark_context
from pyspark.util import _print_missing_jar
if TYPE_CHECKING:
@@ -73,10 +76,9 @@ def from_avro(
[Row(value=Row(avro=Row(age=2, name='Alice')))]
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
try:
- jc = sc._jvm.org.apache.spark.sql.avro.functions.from_avro(
+ jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.from_avro(
_to_java_column(data), jsonFormatSchema, options or {}
)
except TypeError as e:
@@ -119,13 +121,14 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
[Row(suite=bytearray(b'\\x02\\x00'))]
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
try:
if jsonFormatSchema == "":
- jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(_to_java_column(data))
+ jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.to_avro(
+ _to_java_column(data)
+ )
else:
- jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(
+ jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.to_avro(
_to_java_column(data), jsonFormatSchema
)
except TypeError as e:
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index d13d3954bca..0a18930b8eb 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -31,12 +31,13 @@ from typing import (
Union,
)
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JVMView
from pyspark import copy_func
from pyspark.context import SparkContext
from pyspark.errors import PySparkTypeError
from pyspark.sql.types import DataType
+from pyspark.sql.utils import get_active_spark_context
if TYPE_CHECKING:
from pyspark.sql._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral
@@ -46,15 +47,13 @@ __all__ = ["Column"]
def _create_column_from_literal(literal: Union["LiteralType", "DecimalLiteral"]) -> "Column":
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- return sc._jvm.functions.lit(literal)
+ sc = get_active_spark_context()
+ return cast(JVMView, sc._jvm).functions.lit(literal)
def _create_column_from_name(name: str) -> "Column":
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- return sc._jvm.functions.col(name)
+ sc = get_active_spark_context()
+ return cast(JVMView, sc._jvm).functions.col(name)
def _to_java_column(col: "ColumnOrName") -> JavaObject:
@@ -122,9 +121,8 @@ def _unary_op(
def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
def _(self: "Column") -> "Column":
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- jc = getattr(sc._jvm.functions, name)(self._jc)
+ sc = get_active_spark_context()
+ jc = getattr(cast(JVMView, sc._jvm).functions, name)(self._jc)
return Column(jc)
_.__doc__ = doc
@@ -137,9 +135,8 @@ def _bin_func_op(
doc: str = "binary function",
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"]:
def _(self: "Column", other: Union["Column", "LiteralType", "DecimalLiteral"]) -> "Column":
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- fn = getattr(sc._jvm.functions, name)
+ sc = get_active_spark_context()
+ fn = getattr(cast(JVMView, sc._jvm).functions, name)
jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
return Column(njc)
@@ -633,8 +630,7 @@ class Column:
+--------------+
"""
- sc = SparkContext._active_spark_context
- assert sc is not None
+ sc = get_active_spark_context()
jc = self._jc.dropFields(_to_seq(sc, fieldNames))
return Column(jc)
@@ -962,8 +958,7 @@ class Column:
Tuple,
[c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols],
)
- sc = SparkContext._active_spark_context
- assert sc is not None
+ sc = get_active_spark_context()
jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
return Column(jc)
@@ -1144,8 +1139,7 @@ class Column:
metadata = kwargs.pop("metadata", None)
assert not kwargs, "Unexpected kwargs where passed: %s" % kwargs
- sc = SparkContext._active_spark_context
- assert sc is not None
+ sc = get_active_spark_context()
if len(alias) == 1:
if metadata:
assert sc._jvm is not None
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cc5d264bd34..518bc9867d7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -39,7 +39,7 @@ from typing import (
TYPE_CHECKING,
)
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JVMView
from pyspark import copy_func, _NoValue
from pyspark._globals import _NoValueType
@@ -61,6 +61,7 @@ from pyspark.sql.types import (
Row,
_parse_datatype_json_string,
)
+from pyspark.sql.utils import get_active_spark_context
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
@@ -4899,9 +4900,10 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
error_class="NOT_DICT",
message_parameters={"arg_name": "metadata", "arg_type": type(metadata).__name__},
)
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(json.dumps(metadata))
+ sc = get_active_spark_context()
+ jmeta = cast(JVMView, sc._jvm).org.apache.spark.sql.types.Metadata.fromJson(
+ json.dumps(metadata)
+ )
return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)
@overload
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bb5a1a559be..ab099554293 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -37,6 +37,8 @@ from typing import (
ValuesView,
)
+from py4j.java_gateway import JVMView
+
from pyspark import SparkContext
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.rdd import PythonEvalType
@@ -49,7 +51,12 @@ from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401
# Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType # noqa: F401
-from pyspark.sql.utils import to_str, has_numpy, try_remote_functions
+from pyspark.sql.utils import (
+ to_str,
+ has_numpy,
+ try_remote_functions,
+ get_active_spark_context,
+)
if TYPE_CHECKING:
from pyspark.sql._typing import (
@@ -101,8 +108,7 @@ def _invoke_function_over_seq_of_columns(name: str, cols: "Iterable[ColumnOrName
Invokes unary JVM function identified by name with
and wraps the result with :class:`~pyspark.sql.Column`.
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
return _invoke_function(name, _to_seq(sc, cols, _to_java_column))
@@ -2676,9 +2682,8 @@ def broadcast(df: DataFrame) -> DataFrame:
+-----+---+
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sparkSession)
+ sc = get_active_spark_context()
+ return DataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession)
@try_remote_functions
@@ -2891,8 +2896,7 @@ def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column:
| 4|
+----------------------------+
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
return _invoke_function(
"count_distinct", _to_java_column(col), _to_seq(sc, cols, _to_java_column)
)
@@ -3304,8 +3308,7 @@ def percentile_approx(
|-- key: long (nullable = true)
|-- median: double (nullable = true)
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
if isinstance(percentage, (list, tuple)):
# A local list
@@ -6226,8 +6229,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> Column:
>>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
[Row(s='abcd-123')]
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
return _invoke_function("concat_ws", sep, _to_seq(sc, cols, _to_java_column))
@@ -6360,8 +6362,7 @@ def format_string(format: str, *cols: "ColumnOrName") -> Column:
>>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
[Row(v='5 hello')]
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
return _invoke_function("format_string", format, _to_seq(sc, cols, _to_java_column))
@@ -7419,8 +7420,7 @@ def array_join(
>>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
[Row(joined='a,b,c'), Row(joined='a,NULL')]
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ get_active_spark_context()
if null_replacement is None:
return _invoke_function("array_join", _to_java_column(col), delimiter)
else:
@@ -8229,8 +8229,7 @@ def json_tuple(col: "ColumnOrName", *fields: str) -> Column:
>>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect()
[Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)]
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
return _invoke_function("json_tuple", _to_java_column(col), _to_seq(sc, fields))
@@ -9182,8 +9181,7 @@ def from_csv(
[Row(csv=Row(s='abc'))]
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ get_active_spark_context()
if isinstance(schema, str):
schema = _create_column_from_literal(schema)
elif isinstance(schema, Column):
@@ -9209,11 +9207,12 @@ def _unresolved_named_lambda_variable(*name_parts: Any) -> Column:
----------
name_parts : str
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
name_parts_seq = _to_seq(sc, name_parts)
- expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
- return Column(sc._jvm.Column(expressions.UnresolvedNamedLambdaVariable(name_parts_seq)))
+ expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions
+ return Column(
+ cast(JVMView, sc._jvm).Column(expressions.UnresolvedNamedLambdaVariable(name_parts_seq))
+ )
def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
@@ -9258,9 +9257,8 @@ def _create_lambda(f: Callable) -> Callable:
"""
parameters = _get_lambda_parameters(f)
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
+ sc = get_active_spark_context()
+ expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions
argnames = ["x", "y", "z"]
args = [
@@ -9300,15 +9298,14 @@ def _invoke_higher_order_function(
:return: a Column
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
+ sc = get_active_spark_context()
+ expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions
expr = getattr(expressions, name)
jcols = [_to_java_column(col).expr() for col in cols]
jfuns = [_create_lambda(f) for f in funs]
- return Column(sc._jvm.Column(expr(*jcols + jfuns)))
+ return Column(cast(JVMView, sc._jvm).Column(expr(*jcols + jfuns)))
@overload
@@ -10017,8 +10014,7 @@ def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column:
message_parameters={"arg_name": "numBuckets", "arg_type": type(numBuckets).__name__},
)
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ get_active_spark_context()
numBuckets = (
_create_column_from_literal(numBuckets)
if isinstance(numBuckets, int)
@@ -10070,8 +10066,7 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
| cc|
+-----------+
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
return _invoke_function("call_udf", udfName, _to_seq(sc, cols, _to_java_column))
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 1fed9cfda66..a303cf91493 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -20,9 +20,12 @@ A collections of builtin protobuf functions
"""
-from typing import Dict, Optional, TYPE_CHECKING
-from pyspark import SparkContext
+from typing import Dict, Optional, TYPE_CHECKING, cast
+
+from py4j.java_gateway import JVMView
+
from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.utils import get_active_spark_context
from pyspark.util import _print_missing_jar
if TYPE_CHECKING:
@@ -117,15 +120,14 @@ def from_protobuf(
+------------------+
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
try:
if descFilePath is not None:
- jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
+ jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.from_protobuf(
_to_java_column(data), messageName, descFilePath, options or {}
)
else:
- jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
+ jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.from_protobuf(
_to_java_column(data), messageName, options or {}
)
except TypeError as e:
@@ -212,15 +214,14 @@ def to_protobuf(
+----------------------------+
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
+ sc = get_active_spark_context()
try:
if descFilePath is not None:
- jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf(
+ jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.to_protobuf(
_to_java_column(data), messageName, descFilePath, options or {}
)
else:
- jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf(
+ jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.to_protobuf(
_to_java_column(data), messageName, options or {}
)
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 1b414baeec3..d8a464b006f 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -858,6 +858,13 @@ class UDFInitializationTests(unittest.TestCase):
"SparkSession shouldn't be initialized when UserDefinedFunction is created.",
)
+ def test_err_parse_type_when_no_sc(self):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "SparkContext or SparkSession should be created first",
+ ):
+ udf(lambda x: x, "integer")
+
if __name__ == "__main__":
from pyspark.sql.tests.test_udf import * # noqa: F401
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9cb17e85540..ff43e4b00e9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -46,10 +46,10 @@ from typing import (
)
from py4j.protocol import register_input_converter
-from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject
+from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject, JVMView
from pyspark.serializers import CloudPickleSerializer
-from pyspark.sql.utils import has_numpy
+from pyspark.sql.utils import has_numpy, get_active_spark_context
if has_numpy:
import numpy as np
@@ -1208,21 +1208,18 @@ def _parse_datatype_string(s: str) -> DataType:
...
ParseException:...
"""
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- assert sc is not None
+ sc = get_active_spark_context()
def from_ddl_schema(type_str: str) -> DataType:
- assert sc is not None and sc._jvm is not None
return _parse_datatype_json_string(
- sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json()
+ cast(JVMView, sc._jvm).org.apache.spark.sql.types.StructType.fromDDL(type_str).json()
)
def from_ddl_datatype(type_str: str) -> DataType:
- assert sc is not None and sc._jvm is not None
return _parse_datatype_json_string(
- sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json()
+ cast(JVMView, sc._jvm)
+ .org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str)
+ .json()
)
try:
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index e79d04141ae..c1fa3d187fe 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -40,6 +40,7 @@ from pyspark.sql.types import (
StructType,
_parse_datatype_string,
)
+from pyspark.sql.utils import get_active_spark_context
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
@@ -334,8 +335,7 @@ class UserDefinedFunction:
return judf
def __call__(self, *cols: "ColumnOrName") -> Column:
- sc = SparkContext._active_spark_context
- assert sc is not None
+ sc = get_active_spark_context()
profiler: Optional[Profiler] = None
memory_profiler: Optional[Profiler] = None
if sc.profiler_collector:
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b9b045541a6..b5d17e38b87 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -193,6 +193,15 @@ def try_remote_windowspec(f: FuncT) -> FuncT:
return cast(FuncT, wrapped)
+def get_active_spark_context() -> SparkContext:
+ """Raise RuntimeError if SparkContext is not initialized,
+ otherwise, returns the active SparkContext."""
+ sc = SparkContext._active_spark_context
+ if sc is None or sc._jvm is None:
+ raise RuntimeError("SparkContext or SparkSession should be created first.")
+ return sc
+
+
def try_remote_observation(f: FuncT) -> FuncT:
"""Mark API supported from Spark Connect."""
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 92b251ba63f..ca05cb0cc7f 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -17,11 +17,14 @@
import sys
from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JVMView
-from pyspark import SparkContext
from pyspark.sql.column import _to_seq, _to_java_column
-from pyspark.sql.utils import try_remote_window, try_remote_windowspec
+from pyspark.sql.utils import (
+ try_remote_window,
+ try_remote_windowspec,
+ get_active_spark_context,
+)
if TYPE_CHECKING:
from pyspark.sql._typing import ColumnOrName, ColumnOrName_
@@ -30,10 +33,9 @@ __all__ = ["Window", "WindowSpec"]
def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> JavaObject:
- sc = SparkContext._active_spark_context
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0] # type: ignore[assignment]
- assert sc is not None
+ sc = get_active_spark_context()
return _to_seq(sc, cast(Iterable["ColumnOrName"], cols), _to_java_column)
@@ -123,9 +125,10 @@ class Window:
| 3| b| 3|
+---+--------+----------+
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols))
+ sc = get_active_spark_context()
+ jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.partitionBy(
+ _to_java_cols(cols)
+ )
return WindowSpec(jspec)
@staticmethod
@@ -179,9 +182,10 @@ class Window:
| 3| b| 1|
+---+--------+----------+
"""
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
+ sc = get_active_spark_context()
+ jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.orderBy(
+ _to_java_cols(cols)
+ )
return WindowSpec(jspec)
@staticmethod
@@ -263,9 +267,10 @@ class Window:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
end = Window.unboundedFollowing
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end)
+ sc = get_active_spark_context()
+ jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.rowsBetween(
+ start, end
+ )
return WindowSpec(jspec)
@staticmethod
@@ -350,9 +355,10 @@ class Window:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
end = Window.unboundedFollowing
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._jvm is not None
- jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end)
+ sc = get_active_spark_context()
+ jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.rangeBetween(
+ start, end
+ )
return WindowSpec(jspec)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org