You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2014/09/16 04:28:24 UTC

git commit: [SPARK-1087] Move python traceback utilities into new traceback_utils.py file.

Repository: spark
Updated Branches:
  refs/heads/master da33acb8b -> 60050f428


[SPARK-1087] Move python traceback utilities into new traceback_utils.py file.

Also made some cosmetic cleanups.

Author: Aaron Staple <aa...@gmail.com>

Closes #2385 from staple/SPARK-1087 and squashes the following commits:

7b3bb13 [Aaron Staple] Address review comments, cosmetic cleanups.
10ba6e1 [Aaron Staple] [SPARK-1087] Move python traceback utilities into new traceback_utils.py file.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/60050f42
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/60050f42
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/60050f42

Branch: refs/heads/master
Commit: 60050f42885582a699fc7a6fa0529964162bb8a3
Parents: da33acb
Author: Aaron Staple <aa...@gmail.com>
Authored: Mon Sep 15 19:28:17 2014 -0700
Committer: Josh Rosen <jo...@apache.org>
Committed: Mon Sep 15 19:28:17 2014 -0700

----------------------------------------------------------------------
 python/pyspark/context.py         |  8 +---
 python/pyspark/rdd.py             | 58 ++-----------------------
 python/pyspark/traceback_utils.py | 78 ++++++++++++++++++++++++++++++++++
 3 files changed, 83 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/60050f42/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ea28e8c..a33aae8 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -20,7 +20,6 @@ import shutil
 import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
-from collections import namedtuple
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deseria
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
+from pyspark.traceback_utils import CallSite, first_spark_call
 
 from py4j.java_collections import ListConverter
 
@@ -99,11 +99,7 @@ class SparkContext(object):
             ...
         ValueError:...
         """
-        if rdd._extract_concise_traceback() is not None:
-            self._callsite = rdd._extract_concise_traceback()
-        else:
-            tempNamedTuple = namedtuple("Callsite", "function file linenum")
-            self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
+        self._callsite = first_spark_call() or CallSite(None, None, None)
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,

http://git-wip-us.apache.org/repos/asf/spark/blob/60050f42/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6ad5ab2..21f182b 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,13 +18,11 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from collections import namedtuple
 from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
 import shlex
-import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -45,6 +43,7 @@ from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
     get_used_memory, ExternalSorter
+from pyspark.traceback_utils import SCCallSiteSync
 
 from py4j.java_collections import ListConverter, MapConverter
 
@@ -81,57 +80,6 @@ def portable_hash(x):
     return hash(x)
 
 
-def _extract_concise_traceback():
-    """
-    This function returns the traceback info for a callsite, returns a dict
-    with function name, file name and line number
-    """
-    tb = traceback.extract_stack()
-    callsite = namedtuple("Callsite", "function file linenum")
-    if len(tb) == 0:
-        return None
-    file, line, module, what = tb[len(tb) - 1]
-    sparkpath = os.path.dirname(file)
-    first_spark_frame = len(tb) - 1
-    for i in range(0, len(tb)):
-        file, line, fun, what = tb[i]
-        if file.startswith(sparkpath):
-            first_spark_frame = i
-            break
-    if first_spark_frame == 0:
-        file, line, fun, what = tb[0]
-        return callsite(function=fun, file=file, linenum=line)
-    sfile, sline, sfun, swhat = tb[first_spark_frame]
-    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
-    return callsite(function=sfun, file=ufile, linenum=uline)
-
-_spark_stack_depth = 0
-
-
-class _JavaStackTrace(object):
-
-    def __init__(self, sc):
-        tb = _extract_concise_traceback()
-        if tb is not None:
-            self._traceback = "%s at %s:%s" % (
-                tb.function, tb.file, tb.linenum)
-        else:
-            self._traceback = "Error! Could not extract traceback info"
-        self._context = sc
-
-    def __enter__(self):
-        global _spark_stack_depth
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(self._traceback)
-        _spark_stack_depth += 1
-
-    def __exit__(self, type, value, tb):
-        global _spark_stack_depth
-        _spark_stack_depth -= 1
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(None)
-
-
 class BoundedFloat(float):
     """
     Bounded value is generated by approximate job, with confidence and low
@@ -704,7 +652,7 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        with _JavaStackTrace(self.context) as st:
+        with SCCallSiteSync(self.context) as css:
             bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
@@ -1515,7 +1463,7 @@ class RDD(object):
 
         keyed = self.mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
-        with _JavaStackTrace(self.context) as st:
+        with SCCallSiteSync(self.context) as css:
             pairRDD = self.ctx._jvm.PairwiseRDD(
                 keyed._jrdd.rdd()).asJavaPairRDD()
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,

http://git-wip-us.apache.org/repos/asf/spark/blob/60050f42/python/pyspark/traceback_utils.py
----------------------------------------------------------------------
diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py
new file mode 100644
index 0000000..bb8646d
--- /dev/null
+++ b/python/pyspark/traceback_utils.py
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+import os
+import traceback
+
+
+CallSite = namedtuple("CallSite", "function file linenum")
+
+
+def first_spark_call():
+    """
+    Return the CallSite of the first Spark call on the stack, or None if the stack is empty.
+    """
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return None
+    file, line, module, what = tb[len(tb) - 1]  # frame of this function itself
+    sparkpath = os.path.dirname(file)  # directory holding this module (the pyspark package)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):  # NOTE(review): plain prefix match, no separator check
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]  # outermost frame is already Spark's: report it as-is
+        return CallSite(function=fun, file=file, linenum=line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]  # first frame inside Spark
+    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]  # frame that called into Spark
+    return CallSite(function=sfun, file=ufile, linenum=uline)
+
+
+class SCCallSiteSync(object):
+    """
+    Context manager that sets the SparkContext call site for the outermost Spark call.
+
+    Example usage:
+    from pyspark.traceback_utils import SCCallSiteSync
+    with SCCallSiteSync(<relevant SparkContext>) as css:
+        <a Spark call>
+    """
+
+    _spark_stack_depth = 0  # nesting depth, shared across all instances (class attribute)
+
+    def __init__(self, sc):
+        call_site = first_spark_call()  # captured at construction time, not at __enter__
+        if call_site is not None:
+            self._call_site = "%s at %s:%s" % (
+                call_site.function, call_site.file, call_site.linenum)
+        else:
+            self._call_site = "Error! Could not extract traceback info"
+        self._context = sc
+
+    def __enter__(self):  # NOTE(review): returns None, so "as css" binds None
+        if SCCallSiteSync._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._call_site)  # only the outermost call sets it
+        SCCallSiteSync._spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        SCCallSiteSync._spark_stack_depth -= 1
+        if SCCallSiteSync._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)  # clear once the outermost call exits


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org