Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/11 02:59:15 UTC

[spark] branch master updated: [SPARK-40281][PYTHON] Memory Profiler on Executors

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0cb79a4af27 [SPARK-40281][PYTHON] Memory Profiler on Executors
0cb79a4af27 is described below

commit 0cb79a4af27eb226a4f7574aa1c8b54412222583
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Fri Nov 11 11:59:00 2022 +0900

    [SPARK-40281][PYTHON] Memory Profiler on Executors
    
    ### What changes were proposed in this pull request?
    Implement PySpark memory profiling on executors. The feature is enabled via the newly introduced Spark configuration `spark.python.profile.memory`.
    
    See the [design doc](https://docs.google.com/document/d/e/2PACX-1vR2K4TdrM1eAjNDC1bsflCNRH67UWLoC-lCv6TSUVXD91Ruksm99pYTnCeIm7Ui3RgrrRNcQU_D8-oh/pub) for more details.
    
    ### Why are the changes needed?
    Many factors affect a PySpark program’s performance. Memory, one of the key factors, has been missing from PySpark profiling. A PySpark program on the Spark driver can be profiled with [Memory Profiler](https://pypi.org/project/memory-profiler/) as a normal Python process, but there has been no easy way to profile memory on Spark executors.
    
    PySpark UDFs, among the most popular PySpark APIs, enable users to run custom code on top of the Apache Spark™ engine. However, it is difficult to optimize UDFs without understanding their memory consumption.
    
    This PR introduces the PySpark memory profiler, which profiles memory on executors. It reports total memory usage and pinpoints which lines of code in a UDF account for the most memory usage, helping users optimize UDFs and reduce the likelihood of out-of-memory errors.
    
    ### Does this PR introduce _any_ user-facing change?
    No changes to existing user-facing behaviors.
    
    A Spark configuration `spark.python.profile.memory` is introduced to enable the PySpark memory profiling feature.
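    
    For context, here is a minimal usage sketch based on the configuration and APIs touched in this commit (the UDF and dump path are illustrative; the `memory-profiler` package must be installed on the cluster):
    
    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf
    
    # Enable executor-side memory profiling at session creation time.
    spark = (
        SparkSession.builder
        .config("spark.python.profile.memory", "true")
        .getOrCreate()
    )
    
    @udf("int")
    def plus_one(v):
        return v + 1
    
    # Running the UDF collects line-by-line memory stats on the executors.
    spark.range(10).select(plus_one("id")).collect()
    
    # Print the collected per-UDF memory profile, or dump it to a directory
    # as udf_<id>_memory.txt files.
    spark.sparkContext.show_profiles()
    spark.sparkContext.dump_profiles("/tmp/memory_profiles")
    ```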
    
    ### How was this patch tested?
    - Unit tests.
    - Manual tests on Jupyter notebooks as shown below:
    ![image](https://user-images.githubusercontent.com/47337188/200998618-73eb5bd1-83ba-4256-9ba4-f6fb4afcd1bd.png)
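    
    For reference, `show_profiles()` renders the executor-side memory profile in the format below (sketched from the new `_show_results` template; the UDF id, line numbers, and MiB values are illustrative):
    
    ```
    ============================================================
    Profile of UDF<id=2>
    ============================================================
    Filename: <command-1598004922717618>
    
    Line #    Mem usage    Increment  Occurrences   Line Contents
    =============================================================
         3    116.9 MiB    116.9 MiB          10   @udf("int")
         4                                         def plus_one(v):
         5    116.9 MiB      0.0 MiB          10       return v + 1
    ```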
    
    Closes #38584 from xinrong-meng/memory_profile.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 dev/requirements.txt                         |   1 +
 dev/sparktestsupport/modules.py              |   1 +
 python/pyspark/context.py                    |  22 +-
 python/pyspark/profiler.py                   | 297 +++++++++++++++++++++++++--
 python/pyspark/rdd.py                        |   5 +-
 python/pyspark/sql/udf.py                    |  70 +++++--
 python/pyspark/tests/test_memory_profiler.py | 160 +++++++++++++++
 python/pyspark/tests/test_profiler.py        |  57 ++++-
 8 files changed, 570 insertions(+), 43 deletions(-)

diff --git a/dev/requirements.txt b/dev/requirements.txt
index 2f32066d6a8..914c26b1fa1 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -10,6 +10,7 @@ plotly
 mlflow>=1.0
 sklearn
 matplotlib<3.3.0
+memory-profiler==0.60.0
 
 # PySpark test dependencies
 unittest-xml-reporting
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index a439b4cbbed..159990fb33a 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -418,6 +418,7 @@ pyspark_core = Module(
         "pyspark.tests.test_daemon",
         "pyspark.tests.test_install_spark",
         "pyspark.tests.test_join",
+        "pyspark.tests.test_memory_profiler",
         "pyspark.tests.test_profiler",
         "pyspark.tests.test_rdd",
         "pyspark.tests.test_rddbarrier",
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 45e683efe7a..9a7a8f46e84 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -67,7 +67,7 @@ from pyspark.rdd import RDD, _load_from_socket
 from pyspark.taskcontext import TaskContext
 from pyspark.traceback_utils import CallSite, first_spark_call
 from pyspark.status import StatusTracker
-from pyspark.profiler import ProfilerCollector, BasicProfiler, UDFBasicProfiler
+from pyspark.profiler import ProfilerCollector, BasicProfiler, UDFBasicProfiler, MemoryProfiler
 from py4j.java_gateway import is_instance_of, JavaGateway, JavaObject, JVMView
 
 if TYPE_CHECKING:
@@ -177,6 +177,7 @@ class SparkContext:
         jsc: Optional[JavaObject] = None,
         profiler_cls: Type[BasicProfiler] = BasicProfiler,
         udf_profiler_cls: Type[UDFBasicProfiler] = UDFBasicProfiler,
+        memory_profiler_cls: Type[MemoryProfiler] = MemoryProfiler,
     ):
 
         if conf is None or conf.get("spark.executor.allowSparkContext", "false").lower() != "true":
@@ -204,6 +205,7 @@ class SparkContext:
                 jsc,
                 profiler_cls,
                 udf_profiler_cls,
+                memory_profiler_cls,
             )
         except BaseException:
             # If an error occurs, clean up in order to allow future SparkContext creation:
@@ -223,6 +225,7 @@ class SparkContext:
         jsc: JavaObject,
         profiler_cls: Type[BasicProfiler] = BasicProfiler,
         udf_profiler_cls: Type[UDFBasicProfiler] = UDFBasicProfiler,
+        memory_profiler_cls: Type[MemoryProfiler] = MemoryProfiler,
     ) -> None:
         self.environment = environment or {}
         # java gateway must have been launched at this point.
@@ -354,9 +357,14 @@ class SparkContext:
         ).getAbsolutePath()
 
         # profiling stats collected for each PythonRDD
-        if self._conf.get("spark.python.profile", "false") == "true":
+        if (
+            self._conf.get("spark.python.profile", "false") == "true"
+            or self._conf.get("spark.python.profile.memory", "false") == "true"
+        ):
             dump_path = self._conf.get("spark.python.profile.dump", None)
-            self.profiler_collector = ProfilerCollector(profiler_cls, udf_profiler_cls, dump_path)
+            self.profiler_collector = ProfilerCollector(
+                profiler_cls, udf_profiler_cls, memory_profiler_cls, dump_path
+            )
         else:
             self.profiler_collector = None  # type: ignore[assignment]
 
@@ -2320,8 +2328,8 @@ class SparkContext:
             self.profiler_collector.show_profiles()
         else:
             raise RuntimeError(
-                "'spark.python.profile' configuration must be set "
-                "to 'true' to enable Python profile."
+                "'spark.python.profile' or 'spark.python.profile.memory' configuration"
+                " must be set to 'true' to enable Python profile."
             )
 
     def dump_profiles(self, path: str) -> None:
@@ -2337,8 +2345,8 @@ class SparkContext:
             self.profiler_collector.dump_profiles(path)
         else:
             raise RuntimeError(
-                "'spark.python.profile' configuration must be set "
-                "to 'true' to enable Python profile."
+                "'spark.python.profile' or 'spark.python.profile.memory' configuration"
+                " must be set to 'true' to enable Python profile."
             )
 
     def getConf(self) -> SparkConf:
diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
index 45365cc1e79..cb668b874a2 100644
--- a/python/pyspark/profiler.py
+++ b/python/pyspark/profiler.py
@@ -15,19 +15,44 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, List, Optional, Type, TYPE_CHECKING, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+    cast,
+)
 
 import cProfile
+import inspect
 import pstats
+import linecache
 import os
 import atexit
 import sys
+import warnings
+
+try:
+    from memory_profiler import choose_backend, CodeMap, LineProfiler  # type: ignore[import]
+
+    has_memory_profiler = True
+except Exception:
+    has_memory_profiler = False
 
 from pyspark.accumulators import AccumulatorParam
 
 if TYPE_CHECKING:
     from pyspark.context import SparkContext
 
+MemoryTuple = Tuple[float, float, int]
+LineProfile = Tuple[int, Optional[MemoryTuple]]
+CodeMapDict = Dict[str, List[LineProfile]]
+
 
 class ProfilerCollector:
     """
@@ -40,10 +65,12 @@ class ProfilerCollector:
         self,
         profiler_cls: Type["Profiler"],
         udf_profiler_cls: Type["Profiler"],
+        memory_profiler_cls: Type["Profiler"],
         dump_path: Optional[str] = None,
     ):
         self.profiler_cls: Type[Profiler] = profiler_cls
         self.udf_profiler_cls: Type[Profiler] = udf_profiler_cls
+        self.memory_profiler_cls: Type[Profiler] = memory_profiler_cls
         self.profile_dump_path: Optional[str] = dump_path
         self.profilers: List[List[Any]] = []
 
@@ -55,6 +82,10 @@ class ProfilerCollector:
         """Create a new profiler using class `udf_profiler_cls`"""
         return self.udf_profiler_cls(ctx)
 
+    def new_memory_profiler(self, ctx: "SparkContext") -> "Profiler":
+        """Create a new profiler using class `memory_profiler_cls`"""
+        return self.memory_profiler_cls(ctx)
+
     def add_profiler(self, id: int, profiler: "Profiler") -> None:
         """Add a profiler for RDD/UDF `id`"""
         if not self.profilers:
@@ -125,27 +156,95 @@ class Profiler:
         """Do profiling on the function `func`"""
         raise NotImplementedError
 
-    def stats(self) -> pstats.Stats:
-        """Return the collected profiling stats (pstats.Stats)"""
+    def stats(self) -> Union[pstats.Stats, Dict]:
+        """Return the collected profiling stats"""
         raise NotImplementedError
 
     def show(self, id: int) -> None:
-        """Print the profile stats to stdout, id is the RDD id"""
-        stats = self.stats()
-        if stats:
-            print("=" * 60)
-            print("Profile of RDD<id=%d>" % id)
-            print("=" * 60)
-            stats.sort_stats("time", "cumulative").print_stats()
+        """Print the profile stats to stdout"""
+        raise NotImplementedError
 
     def dump(self, id: int, path: str) -> None:
-        """Dump the profile into path, id is the RDD id"""
-        if not os.path.exists(path):
-            os.makedirs(path)
-        stats = self.stats()
-        if stats:
-            p = os.path.join(path, "rdd_%d.pstats" % id)
-            stats.dump_stats(p)
+        """Dump the profile into path"""
+        raise NotImplementedError
+
+
+if has_memory_profiler:
+
+    class CodeMapForUDF(CodeMap):
+        def add(
+            self,
+            code: Any,
+            toplevel_code: Optional[Any] = None,
+            *,
+            sub_lines: Optional[List] = None,
+            start_line: Optional[int] = None,
+        ) -> None:
+            if code in self:
+                return
+
+            if toplevel_code is None:
+                toplevel_code = code
+                filename = code.co_filename
+                if sub_lines is None or start_line is None:
+                    (sub_lines, start_line) = inspect.getsourcelines(code)
+                linenos = range(start_line, start_line + len(sub_lines))
+                self._toplevel.append((filename, code, linenos))
+                self[code] = {}
+            else:
+                self[code] = self[toplevel_code]
+            for subcode in filter(inspect.iscode, code.co_consts):
+                self.add(subcode, toplevel_code=toplevel_code)
+
+    class UDFLineProfiler(LineProfiler):
+        def __init__(self, **kw: Any) -> None:
+            include_children = kw.get("include_children", False)
+            backend = kw.get("backend", "psutil")
+            self.code_map = CodeMapForUDF(include_children=include_children, backend=backend)
+            self.enable_count = 0
+            self.max_mem = kw.get("max_mem", None)
+            self.prevlines: List = []
+            self.backend = choose_backend(kw.get("backend", None))
+            self.prev_lineno = None
+
+        def __call__(
+            self,
+            func: Optional[Callable[..., Any]] = None,
+            precision: int = 1,
+            *,
+            sub_lines: Optional[List] = None,
+            start_line: Optional[int] = None,
+        ) -> Callable[..., Any]:
+            if func is not None:
+                self.add_function(func, sub_lines=sub_lines, start_line=start_line)
+                f = self.wrap_function(func)
+                f.__module__ = func.__module__
+                f.__name__ = func.__name__
+                f.__doc__ = func.__doc__
+                f.__dict__.update(getattr(func, "__dict__", {}))
+                return f
+            else:
+
+                def inner_partial(f: Callable[..., Any]) -> Any:
+                    return self.__call__(f, precision=precision)
+
+                return inner_partial
+
+        def add_function(
+            self,
+            func: Callable[..., Any],
+            *,
+            sub_lines: Optional[List] = None,
+            start_line: Optional[int] = None,
+        ) -> None:
+            """Record line profiling information for the given Python function."""
+            try:
+                # func_code does not exist in Python3
+                code = func.__code__
+            except AttributeError:
+                warnings.warn("Could not extract a code object for the object %r" % func)
+            else:
+                self.code_map.add(code, sub_lines=sub_lines, start_line=start_line)
 
 
 class PStatsParam(AccumulatorParam[Optional[pstats.Stats]]):
@@ -165,6 +264,53 @@ class PStatsParam(AccumulatorParam[Optional[pstats.Stats]]):
         return value1
 
 
+class MemUsageParam(AccumulatorParam[Optional[CodeMapDict]]):
+    """MemUsageParam is used to merge memory usage code map"""
+
+    @staticmethod
+    def zero(value: Optional[CodeMapDict]) -> None:
+        return None
+
+    @staticmethod
+    def addInPlace(
+        value1: Optional[CodeMapDict], value2: Optional[CodeMapDict]
+    ) -> Optional[CodeMapDict]:
+        # An example value looks as below
+        # {'<command-1598004922717618>': [(3, (144.2578125, 144.2578125, 1)),
+        #   (4, (0.0, 144.2578125, 1))]}
+        if value1 is None or len(value1) == 0:
+            return value2
+        if value2 is None or len(value2) == 0:
+            return value1
+
+        # value1, value2 should have same keys - file name
+        for filename in value1:
+            l1 = cast(List[LineProfile], value1.get(filename))
+            l2 = cast(List[LineProfile], value2.get(filename))
+            c1 = dict((k, v) for k, v in l1)
+            c2 = dict((k, v) for k, v in l2)
+            udf_code_map: Dict[int, Optional[MemoryTuple]] = {}
+            for lineno in c1:
+                if c1[lineno] and c2[lineno]:
+                    # c1, c2 should have same keys - line number
+                    udf_code_map[lineno] = (
+                        cast(MemoryTuple, c1[lineno])[0]
+                        + cast(MemoryTuple, c2[lineno])[0],  # increment
+                        cast(MemoryTuple, c1[lineno])[1]
+                        + cast(MemoryTuple, c2[lineno])[1],  # mem_usage
+                        cast(MemoryTuple, c1[lineno])[2]
+                        + cast(MemoryTuple, c2[lineno])[2],  # occurrences
+                    )
+                elif c1[lineno]:
+                    udf_code_map[lineno] = cast(MemoryTuple, c1[lineno])
+                elif c2[lineno]:
+                    udf_code_map[lineno] = cast(MemoryTuple, c2[lineno])
+                else:
+                    udf_code_map[lineno] = None
+            value1[filename] = [(k, v) for k, v in udf_code_map.items()]
+        return value1
+
+
 class BasicProfiler(Profiler):
     """
     BasicProfiler is the default profiler, which is implemented based on
@@ -172,7 +318,7 @@ class BasicProfiler(Profiler):
     """
 
     def __init__(self, ctx: "SparkContext") -> None:
-        Profiler.__init__(self, ctx)
+        super().__init__(ctx)
         # Creates a new accumulator for combining the profiles of different
         # partitions of a stage
         self._accumulator = ctx.accumulator(None, PStatsParam)  # type: ignore[arg-type]
@@ -193,6 +339,24 @@ class BasicProfiler(Profiler):
     def stats(self) -> pstats.Stats:
         return cast(pstats.Stats, self._accumulator.value)
 
+    def show(self, id: int) -> None:
+        """Print the profile stats to stdout, id is the RDD id"""
+        stats = self.stats()
+        if stats:
+            print("=" * 60)
+            print("Profile of RDD<id=%d>" % id)
+            print("=" * 60)
+            stats.sort_stats("time", "cumulative").print_stats()
+
+    def dump(self, id: int, path: str) -> None:
+        """Dump the profile into path, id is the RDD id"""
+        if not os.path.exists(path):
+            os.makedirs(path)
+        stats = self.stats()
+        if stats:
+            p = os.path.join(path, "rdd_%d.pstats" % id)
+            stats.dump_stats(p)
+
 
 class UDFBasicProfiler(BasicProfiler):
     """
@@ -218,6 +382,103 @@ class UDFBasicProfiler(BasicProfiler):
             stats.dump_stats(p)
 
 
+class MemoryProfiler(Profiler):
+    """
+    MemoryProfiler, which is implemented based on memory profiler and Accumulator
+    """
+
+    def __init__(self, ctx: "SparkContext") -> None:
+        super().__init__(ctx)
+        # Creates a new accumulator for combining the profiles
+        self._accumulator = ctx.accumulator(None, MemUsageParam)  # type: ignore[arg-type]
+
+    def profile(  # type: ignore
+        self,
+        sub_lines: Optional[List],
+        start_line: Optional[int],
+        func: Callable[..., Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Runs and profiles the method func passed in. A profile object is returned."""
+        if has_memory_profiler:
+            profiler = UDFLineProfiler()
+            wrapped = profiler(func, sub_lines=sub_lines, start_line=start_line)
+            ret = wrapped(*args, **kwargs)
+            codemap_dict = {
+                filename: list(line_iterator)
+                for filename, line_iterator in profiler.code_map.items()
+            }
+            # Adds a new profile to the existing accumulated value
+            self._accumulator.add(codemap_dict)  # type: ignore[arg-type]
+            return ret
+        else:
+            raise RuntimeError(
+                "Install the 'memory_profiler' library in the cluster to enable memory profiling."
+            )
+
+    def stats(self) -> CodeMapDict:
+        """Return the collected memory profiles"""
+        return cast(CodeMapDict, self._accumulator.value)
+
+    def _show_results(
+        self, code_map: CodeMapDict, stream: Optional[Any] = None, precision: int = 1
+    ) -> None:
+        if stream is None:
+            stream = sys.stdout
+        template = "{0:>6} {1:>12} {2:>12}  {3:>10}   {4:<}"
+
+        for (filename, lines) in code_map.items():
+            header = template.format(
+                "Line #", "Mem usage", "Increment", "Occurrences", "Line Contents"
+            )
+
+            stream.write("Filename: " + filename + "\n\n")
+            stream.write(header + "\n")
+            stream.write("=" * len(header) + "\n")
+
+            all_lines = linecache.getlines(filename)
+
+            float_format = "{0}.{1}f".format(precision + 4, precision)
+            template_mem = "{0:" + float_format + "} MiB"
+            for (lineno, mem) in lines:
+                total_mem: Union[float, str]
+                inc: Union[float, str]
+                occurrences: Union[float, str]
+                if mem:
+                    inc = mem[0]
+                    total_mem = mem[1]
+                    total_mem = template_mem.format(total_mem)
+                    occurrences = mem[2]
+                    inc = template_mem.format(inc)
+                else:
+                    total_mem = ""
+                    inc = ""
+                    occurrences = ""
+                tmp = template.format(lineno, total_mem, inc, occurrences, all_lines[lineno - 1])
+                stream.write(tmp)
+            stream.write("\n\n")
+
+    def show(self, id: int) -> None:
+        """Print the profile stats to stdout, id is the PythonUDF id"""
+        code_map = self.stats()
+        if code_map:
+            print("=" * 60)
+            print("Profile of UDF<id=%d>" % id)
+            print("=" * 60)
+            self._show_results(code_map)
+
+    def dump(self, id: int, path: str) -> None:
+        """Dump the memory profile into path, id is the PythonUDF id"""
+        if not os.path.exists(path):
+            os.makedirs(path)
+        stats = self.stats()  # dict
+        if stats:
+            p = os.path.join(path, "udf_%d_memory.txt" % id)
+            with open(p, "w+") as f:
+                self._show_results(stats, stream=f)
+
+
 if __name__ == "__main__":
     import doctest
 
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 5f4f4d494e1..7f5e4e603f4 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -5430,7 +5430,10 @@ class PipelinedRDD(RDD[U], Generic[T, U]):
         if self._bypass_serializer:
             self._jrdd_deserializer = NoOpSerializer()
 
-        if self.ctx.profiler_collector:
+        if (
+            self.ctx.profiler_collector
+            and self.ctx._conf.get("spark.python.profile", "false") == "true"
+        ):
             profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
         else:
             profiler = None
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index da9a245bb71..7c7be392cd3 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -20,6 +20,7 @@ User-defined function related classes and functions
 import functools
 import inspect
 import sys
+import warnings
 from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union
 
 from py4j.java_gateway import JavaObject
@@ -236,25 +237,66 @@ class UserDefinedFunction:
         sc = SparkContext._active_spark_context
         assert sc is not None
         profiler: Optional[Profiler] = None
+        memory_profiler: Optional[Profiler] = None
         if sc.profiler_collector:
-            f = self.func
-            profiler = sc.profiler_collector.new_udf_profiler(sc)
-
-            @functools.wraps(f)
-            def func(*args: Any, **kwargs: Any) -> Any:
-                assert profiler is not None
-                return profiler.profile(f, *args, **kwargs)
+            profiler_enabled = sc._conf.get("spark.python.profile", "false") == "true"
+            memory_profiler_enabled = sc._conf.get("spark.python.profile.memory", "false") == "true"
+
+            # Disable profiling Pandas UDFs with iterators as input/output.
+            if profiler_enabled or memory_profiler_enabled:
+                if self.evalType in [
+                    PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+                    PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+                    PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
+                ]:
+                    profiler_enabled = memory_profiler_enabled = False
+                    warnings.warn(
+                        "Profiling UDFs with iterators input/output is not supported.",
+                        UserWarning,
+                    )
 
-            func.__signature__ = inspect.signature(f)  # type: ignore[attr-defined]
+            # Disallow enabling two profilers at the same time.
+            if profiler_enabled and memory_profiler_enabled:
+                # When both profilers are enabled, they interfere with each other,
+                # that makes the result profile misleading.
+                raise RuntimeError(
+                    "'spark.python.profile' and 'spark.python.profile.memory' configuration"
+                    " cannot be enabled together."
+                )
+            elif profiler_enabled:
+                f = self.func
+                profiler = sc.profiler_collector.new_udf_profiler(sc)
+
+                @functools.wraps(f)
+                def func(*args: Any, **kwargs: Any) -> Any:
+                    assert profiler is not None
+                    return profiler.profile(f, *args, **kwargs)
+
+                func.__signature__ = inspect.signature(f)  # type: ignore[attr-defined]
+                judf = self._create_judf(func)
+                jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
+                id = jPythonUDF.expr().resultId().id()
+                sc.profiler_collector.add_profiler(id, profiler)
+            else:  # memory_profiler_enabled
+                f = self.func
+                memory_profiler = sc.profiler_collector.new_memory_profiler(sc)
+                (sub_lines, start_line) = inspect.getsourcelines(f.__code__)
+
+                @functools.wraps(f)
+                def func(*args: Any, **kwargs: Any) -> Any:
+                    assert memory_profiler is not None
+                    return memory_profiler.profile(
+                        sub_lines, start_line, f, *args, **kwargs  # type: ignore[arg-type]
+                    )
 
-            judf = self._create_judf(func)
+                func.__signature__ = inspect.signature(f)  # type: ignore[attr-defined]
+                judf = self._create_judf(func)
+                jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
+                id = jPythonUDF.expr().resultId().id()
+                sc.profiler_collector.add_profiler(id, memory_profiler)
         else:
             judf = self._judf
-
-        jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
-        if profiler is not None:
-            id = jPythonUDF.expr().resultId().id()
-            sc.profiler_collector.add_profiler(id, profiler)
+            jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
         return Column(jPythonUDF)
 
     # This function is for improving the online help system in the interactive interpreter.
diff --git a/python/pyspark/tests/test_memory_profiler.py b/python/pyspark/tests/test_memory_profiler.py
new file mode 100644
index 00000000000..7da82dccb37
--- /dev/null
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -0,0 +1,160 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import tempfile
+import unittest
+import warnings
+from io import StringIO
+from typing import Iterator
+from unittest import mock
+
+import pandas as pd
+
+from pyspark import SparkConf, SparkContext
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import pandas_udf, udf
+from pyspark.testing.utils import PySparkTestCase
+
+try:
+    import memory_profiler  # type: ignore[import] # noqa: F401
+
+    has_memory_profiler = True
+except Exception:
+    has_memory_profiler = False
+
+
+@unittest.skipIf(not has_memory_profiler, "Must have memory-profiler installed.")
+class MemoryProfilerTests(PySparkTestCase):
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.python.profile.memory", "true")
+        self.sc = SparkContext("local[4]", class_name, conf=conf)
+        self.spark = SparkSession(sparkContext=self.sc)
+
+    def test_memory_profiler(self):
+        self.exec_python_udf()
+
+        profilers = self.sc.profiler_collector.profilers
+        self.assertEqual(1, len(profilers))
+        id, profiler, _ = profilers[0]
+        stats = profiler.stats()
+        self.assertTrue(stats is not None)
+
+        with mock.patch("sys.stdout", new=StringIO()) as fake_out:
+            self.sc.show_profiles()
+        self.assertTrue("plus_one" in fake_out.getvalue())
+
+        d = tempfile.gettempdir()
+        self.sc.dump_profiles(d)
+        self.assertTrue("udf_%d_memory.txt" % id in os.listdir(d))
+
+    def test_profile_pandas_udf(self):
+        udfs = [self.exec_pandas_udf_ser_to_ser, self.exec_pandas_udf_ser_to_scalar]
+        udf_names = ["ser_to_ser", "ser_to_scalar"]
+        for f, f_name in zip(udfs, udf_names):
+            f()
+            with mock.patch("sys.stdout", new=StringIO()) as fake_out:
+                self.sc.show_profiles()
+            self.assertTrue(f_name in fake_out.getvalue())
+
+        with warnings.catch_warnings(record=True) as warns:
+            warnings.simplefilter("always")
+            self.exec_pandas_udf_iter_to_iter()
+            user_warns = [warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+            self.assertTrue(len(user_warns) > 0)
+            self.assertTrue(
+                "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0])
+            )
+
+    def test_profile_pandas_function_api(self):
+        apis = [self.exec_grouped_map]
+        f_names = ["grouped_map"]
+        for api, f_name in zip(apis, f_names):
+            api()
+            with mock.patch("sys.stdout", new=StringIO()) as fake_out:
+                self.sc.show_profiles()
+            self.assertTrue(f_name in fake_out.getvalue())
+
+        with warnings.catch_warnings(record=True) as warns:
+            warnings.simplefilter("always")
+            self.exec_map()
+            user_warns = [warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+            self.assertTrue(len(user_warns) > 0)
+            self.assertTrue(
+                "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0])
+            )
+
+    def exec_python_udf(self):
+        @udf("int")
+        def plus_one(v):
+            return v + 1
+
+        self.spark.range(10).select(plus_one("id")).collect()
+
+    def exec_pandas_udf_ser_to_ser(self):
+        @pandas_udf("int")
+        def ser_to_ser(ser: pd.Series) -> pd.Series:
+            return ser + 1
+
+        self.spark.range(10).select(ser_to_ser("id")).collect()
+
+    def exec_pandas_udf_ser_to_scalar(self):
+        @pandas_udf("int")
+        def ser_to_scalar(ser: pd.Series) -> float:
+            return ser.median()
+
+        self.spark.range(10).select(ser_to_scalar("id")).collect()
+
+    # Unsupported
+    def exec_pandas_udf_iter_to_iter(self):
+        @pandas_udf("int")
+        def iter_to_iter(batch_ser: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            for ser in batch_ser:
+                yield ser + 1
+
+        self.spark.range(10).select(iter_to_iter("id")).collect()
+
+    def exec_grouped_map(self):
+        def grouped_map(pdf: pd.DataFrame) -> pd.DataFrame:
+            return pdf.assign(v=pdf.v - pdf.v.mean())
+
+        df = self.spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))
+        df.groupby("id").applyInPandas(grouped_map, schema="id long, v double").collect()
+
+    # Unsupported
+    def exec_map(self):
+        def map(pdfs: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            for pdf in pdfs:
+                yield pdf[pdf.id == 1]
+
+        df = self.spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))
+        df.mapInPandas(map, schema=df.schema).collect()
+
+
+if __name__ == "__main__":
+    from pyspark.tests.test_memory_profiler import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py
index d13e4ad5683..ceae904ca6f 100644
--- a/python/pyspark/tests/test_profiler.py
+++ b/python/pyspark/tests/test_profiler.py
@@ -22,6 +22,9 @@ import unittest
 from io import StringIO
 
 from pyspark import SparkConf, SparkContext, BasicProfiler
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import udf
+from pyspark.sql.utils import PythonException
 from pyspark.testing.utils import PySparkTestCase
 
 
@@ -82,21 +85,69 @@ class ProfilerTests(PySparkTestCase):
 
 class ProfilerTests2(unittest.TestCase):
     def test_profiler_disabled(self):
-        sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false"))
+        sc = SparkContext(
+            conf=SparkConf()
+            .set("spark.python.profile", "false")
+            .set("spark.python.profile.memory", "false")
+        )
         try:
             self.assertRaisesRegex(
                 RuntimeError,
-                "'spark.python.profile' configuration must be set",
+                "'spark.python.profile' or 'spark.python.profile.memory' configuration must be set",
                 lambda: sc.show_profiles(),
             )
             self.assertRaisesRegex(
                 RuntimeError,
-                "'spark.python.profile' configuration must be set",
+                "'spark.python.profile' or 'spark.python.profile.memory' configuration must be set",
                 lambda: sc.dump_profiles("/tmp/abc"),
             )
         finally:
             sc.stop()
 
+    def test_profiler_all_enabled(self):
+        sc = SparkContext(
+            conf=SparkConf()
+            .set("spark.python.profile", "true")
+            .set("spark.python.profile.memory", "true")
+        )
+        spark = SparkSession(sparkContext=sc)
+
+        @udf("int")
+        def plus_one(v):
+            return v + 1
+
+        try:
+            self.assertRaisesRegex(
+                RuntimeError,
+                "'spark.python.profile' and 'spark.python.profile.memory' configuration"
+                " cannot be enabled together",
+                lambda: spark.range(10).select(plus_one("id")).collect(),
+            )
+        finally:
+            sc.stop()
+
+    def test_no_memory_profile_installed(self):
+        sc = SparkContext(
+            conf=SparkConf()
+            .set("spark.python.profile", "false")
+            .set("spark.python.profile.memory", "true")
+        )
+        spark = SparkSession(sparkContext=sc)
+
+        @udf("int")
+        def plus_one(v):
+            return v + 1
+
+        try:
+            self.assertRaisesRegex(
+                PythonException,
+                "Install the 'memory_profiler' library in the cluster to enable memory "
+                "profiling",
+                lambda: spark.range(10).select(plus_one("id")).collect(),
+            )
+        finally:
+            sc.stop()
+
 
 if __name__ == "__main__":
     from pyspark.tests.test_profiler import *  # noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org